dify
0 dify/api/libs/__init__.py Normal file
134 dify/api/libs/broadcast_channel/channel.py Normal file
@@ -0,0 +1,134 @@
"""
Broadcast channel for Pub/Sub messaging.
"""

import types
from abc import abstractmethod
from collections.abc import Iterator
from contextlib import AbstractContextManager
from typing import Protocol, Self


class Subscription(AbstractContextManager["Subscription"], Protocol):
    """A subscription to a topic that provides an iterator over received messages.

    The subscription can be used as a context manager and will automatically
    close when exiting the context.

    Note: `Subscription` instances are not thread-safe. Each thread should create its own
    subscription.
    """

    @abstractmethod
    def __iter__(self) -> Iterator[bytes]:
        """`__iter__` returns an iterator used to consume messages from this subscription.

        If the caller did not enter the context, `__iter__` may lazily perform the setup before
        yielding messages; otherwise `__enter__` handles it.

        If the subscription is closed, the returned iterator exits without
        raising any error.
        """
        ...

    @abstractmethod
    def close(self) -> None:
        """`close` closes the subscription and releases any resources associated with it."""
        ...

    def __enter__(self) -> Self:
        """`__enter__` does the setup logic of the subscription (if any), and returns itself."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> bool | None:
        self.close()
        return None

    @abstractmethod
    def receive(self, timeout: float | None = 0.1) -> bytes | None:
        """Receive the next message from the broadcast channel.

        If `timeout` is specified, this method returns `None` if no message is
        received within the given period. If `timeout` is `None`, the call blocks
        until a message is received.

        Calling `receive` with `timeout=None` is highly discouraged, as it is impossible to
        cancel a blocking subscription.

        :param timeout: timeout for receiving a message, in seconds.

        Returns:
            bytes: The received message as a byte string, or
            None: If the timeout expires before a message is received.

        Raises:
            SubscriptionClosedError: If the subscription has already been closed.
        """
        ...


class Producer(Protocol):
    """Producer is an interface for message publishing. It is already bound to a specific topic.

    `Producer` implementations must be thread-safe and support concurrent use by multiple threads.
    """

    @abstractmethod
    def publish(self, payload: bytes) -> None:
        """Publish a message to the bound topic."""
        ...


class Subscriber(Protocol):
    """Subscriber is an interface for subscription creation. It is already bound to a specific topic.

    `Subscriber` implementations must be thread-safe and support concurrent use by multiple threads.
    """

    @abstractmethod
    def subscribe(self) -> Subscription:
        """`subscribe` creates a new `Subscription` to the bound topic."""
        ...


class Topic(Producer, Subscriber, Protocol):
    """A named channel for publishing and subscribing to messages.

    Topics provide both read and write access. For restricted access,
    use `as_producer()` for a write-only view or `as_subscriber()` for a read-only view.

    `Topic` implementations must be thread-safe and support concurrent use by multiple threads.
    """

    @abstractmethod
    def as_producer(self) -> Producer:
        """`as_producer` creates a write-only view for this topic."""
        ...

    @abstractmethod
    def as_subscriber(self) -> Subscriber:
        """`as_subscriber` creates a read-only view for this topic."""
        ...


class BroadcastChannel(Protocol):
    """A broadcast channel is a channel supporting broadcast semantics.

    Each channel is identified by a topic; different topics are isolated and do not affect each other.

    There can be multiple subscriptions to a specific topic. When a publisher publishes a message to
    a specific topic, all subscriptions should receive the published message.

    There is no restriction on the persistence of messages. Once a subscription is created, it
    should receive all subsequent messages published.

    `BroadcastChannel` implementations must be thread-safe and support concurrent use by multiple threads.
    """

    @abstractmethod
    def topic(self, topic: str) -> "Topic":
        """`topic` returns a `Topic` instance for the given topic name."""
        ...
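A minimal usage sketch of these protocols, assuming some concrete `BroadcastChannel` implementation; the topic name and payload below are illustrative:

def fan_out_example(channel: BroadcastChannel) -> None:
    topic = channel.topic("workflow-events")  # hypothetical topic name

    # A subscription only sees messages published after it is created,
    # so set up the read side first.
    with topic.as_subscriber().subscribe() as subscription:
        topic.as_producer().publish(b"hello")  # write-only view
        payload = subscription.receive(timeout=0.5)  # None on timeout
        if payload is not None:
            print(payload)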
12 dify/api/libs/broadcast_channel/exc.py Normal file
@@ -0,0 +1,12 @@
class BroadcastChannelError(Exception):
    """`BroadcastChannelError` is the base class for all exceptions related
    to `BroadcastChannel`."""

    pass


class SubscriptionClosedError(BroadcastChannelError):
    """SubscriptionClosedError means that the subscription has been closed and
    methods for consuming messages should not be called."""

    pass
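A short consumer-loop sketch showing the intended role of `SubscriptionClosedError`: it signals termination, not a retryable failure (`handle` is a placeholder for application logic):

def drain(subscription) -> None:
    while True:
        try:
            payload = subscription.receive(timeout=0.1)
        except SubscriptionClosedError:
            break  # the subscription was closed: stop consuming
        if payload is not None:
            handle(payload)  # placeholder for application logic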
4 dify/api/libs/broadcast_channel/redis/__init__.py Normal file
@@ -0,0 +1,4 @@
from .channel import BroadcastChannel
from .sharded_channel import ShardedRedisBroadcastChannel

__all__ = ["BroadcastChannel", "ShardedRedisBroadcastChannel"]
227 dify/api/libs/broadcast_channel/redis/_subscription.py Normal file
@@ -0,0 +1,227 @@
import logging
import queue
import threading
import types
from collections.abc import Generator, Iterator
from typing import Self

from libs.broadcast_channel.channel import Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from redis.client import PubSub

_logger = logging.getLogger(__name__)


class RedisSubscriptionBase(Subscription):
    """Base class for Redis pub/sub subscriptions with common functionality.

    This class provides shared functionality for both regular and sharded
    Redis pub/sub subscriptions, reducing code duplication and improving
    maintainability.
    """

    def __init__(
        self,
        pubsub: PubSub,
        topic: str,
    ):
        # The _pubsub is None only if the subscription is closed.
        self._pubsub: PubSub | None = pubsub
        self._topic = topic
        self._closed = threading.Event()
        self._queue: queue.Queue[bytes] = queue.Queue(maxsize=1024)
        self._dropped_count = 0
        self._listener_thread: threading.Thread | None = None
        self._start_lock = threading.Lock()
        self._started = False

    def _start_if_needed(self) -> None:
        """Start the subscription if not already started."""
        with self._start_lock:
            if self._started:
                return
            if self._closed.is_set():
                raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
            if self._pubsub is None:
                raise SubscriptionClosedError(
                    f"The Redis {self._get_subscription_type()} subscription has been cleaned up"
                )

            self._subscribe()
            _logger.debug("Subscribed to %s channel %s", self._get_subscription_type(), self._topic)

            self._listener_thread = threading.Thread(
                target=self._listen,
                name=f"redis-{self._get_subscription_type().replace(' ', '-')}-broadcast-{self._topic}",
                daemon=True,
            )
            self._listener_thread.start()
            self._started = True

    def _listen(self) -> None:
        """Main listener loop for processing messages."""
        pubsub = self._pubsub
        assert pubsub is not None, "PubSub should not be None while the listener is starting."
        while not self._closed.is_set():
            try:
                raw_message = self._get_message()
            except Exception as e:
                # Log the exception and exit the listener thread gracefully.
                # This handles Redis connection errors and other exceptions.
                _logger.error(
                    "Error getting message from Redis %s subscription, topic=%s: %s",
                    self._get_subscription_type(),
                    self._topic,
                    e,
                    exc_info=True,
                )
                break

            if raw_message is None:
                continue

            if raw_message.get("type") != self._get_message_type():
                continue

            channel_field = raw_message.get("channel")
            if isinstance(channel_field, bytes):
                channel_name = channel_field.decode("utf-8")
            elif isinstance(channel_field, str):
                channel_name = channel_field
            else:
                channel_name = str(channel_field)

            if channel_name != self._topic:
                _logger.warning(
                    "Ignoring %s message from unexpected channel %s", self._get_subscription_type(), channel_name
                )
                continue

            payload_bytes: bytes | None = raw_message.get("data")
            if not isinstance(payload_bytes, bytes):
                _logger.error(
                    "Received invalid data from %s channel %s, type=%s",
                    self._get_subscription_type(),
                    self._topic,
                    type(payload_bytes),
                )
                continue

            self._enqueue_message(payload_bytes)

        _logger.debug("%s listener thread stopped for channel %s", self._get_subscription_type().title(), self._topic)
        try:
            self._unsubscribe()
            pubsub.close()
            _logger.debug("%s PubSub closed for topic %s", self._get_subscription_type().title(), self._topic)
        except Exception as e:
            _logger.error(
                "Error during cleanup of Redis %s subscription, topic=%s: %s",
                self._get_subscription_type(),
                self._topic,
                e,
                exc_info=True,
            )
        finally:
            self._pubsub = None

    def _enqueue_message(self, payload: bytes) -> None:
        """Enqueue a message to the internal queue with dropping behavior."""
        while not self._closed.is_set():
            try:
                self._queue.put_nowait(payload)
                return
            except queue.Full:
                try:
                    self._queue.get_nowait()
                    self._dropped_count += 1
                    _logger.debug(
                        "Dropped message from Redis %s subscription, topic=%s, total_dropped=%d",
                        self._get_subscription_type(),
                        self._topic,
                        self._dropped_count,
                    )
                except queue.Empty:
                    continue
        return

    def _message_iterator(self) -> Generator[bytes, None, None]:
        """Iterator for consuming messages from the subscription."""
        while not self._closed.is_set():
            try:
                item = self._queue.get(timeout=0.1)
            except queue.Empty:
                continue

            yield item

    def __iter__(self) -> Iterator[bytes]:
        """Return an iterator over messages from the subscription."""
        if self._closed.is_set():
            raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
        self._start_if_needed()
        return iter(self._message_iterator())

    def receive(self, timeout: float | None = None) -> bytes | None:
        """Receive the next message from the subscription."""
        if self._closed.is_set():
            raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
        self._start_if_needed()

        try:
            item = self._queue.get(timeout=timeout)
        except queue.Empty:
            return None

        return item

    def __enter__(self) -> Self:
        """Context manager entry point."""
        self._start_if_needed()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> bool | None:
        """Context manager exit point."""
        self.close()
        return None

    def close(self) -> None:
        """Close the subscription and clean up resources."""
        if self._closed.is_set():
            return

        self._closed.set()
        # NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the
        # message retrieval method should NOT be called concurrently.
        #
        # Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
        listener = self._listener_thread
        if listener is not None:
            listener.join(timeout=1.0)
            self._listener_thread = None

    # Abstract methods to be implemented by subclasses
    def _get_subscription_type(self) -> str:
        """Return the subscription type (e.g., 'regular' or 'sharded')."""
        raise NotImplementedError

    def _subscribe(self) -> None:
        """Subscribe to the Redis topic using the appropriate command."""
        raise NotImplementedError

    def _unsubscribe(self) -> None:
        """Unsubscribe from the Redis topic using the appropriate command."""
        raise NotImplementedError

    def _get_message(self) -> dict | None:
        """Get a message from Redis using the appropriate method."""
        raise NotImplementedError

    def _get_message_type(self) -> str:
        """Return the expected message type (e.g., 'message' or 'smessage')."""
        raise NotImplementedError
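The `_enqueue_message` loop above implements a drop-oldest policy on the bounded queue: when the queue is full, the oldest message is evicted so the newest one fits. A standalone sketch of the same idea (function name is illustrative):

import queue


def put_drop_oldest(q: queue.Queue, payload: bytes) -> int:
    """Enqueue `payload`, evicting the oldest entry when the queue is full.

    Returns the number of messages dropped to make room.
    """
    dropped = 0
    while True:
        try:
            q.put_nowait(payload)
            return dropped
        except queue.Full:
            try:
                q.get_nowait()  # evict the oldest message
                dropped += 1
            except queue.Empty:
                continue  # a consumer raced us; just retry the put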
67 dify/api/libs/broadcast_channel/redis/channel.py Normal file
@@ -0,0 +1,67 @@
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis

from ._subscription import RedisSubscriptionBase


class BroadcastChannel:
    """
    Redis Pub/Sub based broadcast channel implementation (regular, non-sharded).

    Provides "at most once" delivery semantics for messages published to channels
    using Redis PUBLISH/SUBSCRIBE commands for real-time message delivery.

    The `redis_client` used to construct BroadcastChannel should have `decode_responses` set to `False`.
    """

    def __init__(
        self,
        redis_client: Redis,
    ):
        self._client = redis_client

    def topic(self, topic: str) -> "Topic":
        return Topic(self._client, topic)


class Topic:
    def __init__(self, redis_client: Redis, topic: str):
        self._client = redis_client
        self._topic = topic

    def as_producer(self) -> Producer:
        return self

    def publish(self, payload: bytes) -> None:
        self._client.publish(self._topic, payload)

    def as_subscriber(self) -> Subscriber:
        return self

    def subscribe(self) -> Subscription:
        return _RedisSubscription(
            pubsub=self._client.pubsub(),
            topic=self._topic,
        )


class _RedisSubscription(RedisSubscriptionBase):
    """Regular Redis pub/sub subscription implementation."""

    def _get_subscription_type(self) -> str:
        return "regular"

    def _subscribe(self) -> None:
        assert self._pubsub is not None
        self._pubsub.subscribe(self._topic)

    def _unsubscribe(self) -> None:
        assert self._pubsub is not None
        self._pubsub.unsubscribe(self._topic)

    def _get_message(self) -> dict | None:
        assert self._pubsub is not None
        return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)

    def _get_message_type(self) -> str:
        return "message"
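A hypothetical end-to-end sketch (host, port, and topic name are illustrative). Note that Redis Pub/Sub is at-most-once: only live subscriptions observe a message, so subscribe before publishing:

from redis import Redis

client = Redis(host="localhost", port=6379, decode_responses=False)
topic = BroadcastChannel(client).topic("app:42:events")

with topic.subscribe() as sub:
    topic.publish(b'{"event": "started"}')
    payload = sub.receive(timeout=1.0)  # None if nothing arrives in time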
65 dify/api/libs/broadcast_channel/redis/sharded_channel.py Normal file
@@ -0,0 +1,65 @@
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis

from ._subscription import RedisSubscriptionBase


class ShardedRedisBroadcastChannel:
    """
    Redis 7.0+ Sharded Pub/Sub based broadcast channel implementation.

    Provides "at most once" delivery semantics using SPUBLISH/SSUBSCRIBE commands,
    distributing channels across Redis cluster nodes for better scalability.
    """

    def __init__(
        self,
        redis_client: Redis,
    ):
        self._client = redis_client

    def topic(self, topic: str) -> "ShardedTopic":
        return ShardedTopic(self._client, topic)


class ShardedTopic:
    def __init__(self, redis_client: Redis, topic: str):
        self._client = redis_client
        self._topic = topic

    def as_producer(self) -> Producer:
        return self

    def publish(self, payload: bytes) -> None:
        self._client.spublish(self._topic, payload)  # type: ignore[attr-defined]

    def as_subscriber(self) -> Subscriber:
        return self

    def subscribe(self) -> Subscription:
        return _RedisShardedSubscription(
            pubsub=self._client.pubsub(),
            topic=self._topic,
        )


class _RedisShardedSubscription(RedisSubscriptionBase):
    """Redis 7.0+ sharded pub/sub subscription implementation."""

    def _get_subscription_type(self) -> str:
        return "sharded"

    def _subscribe(self) -> None:
        assert self._pubsub is not None
        self._pubsub.ssubscribe(self._topic)  # type: ignore[attr-defined]

    def _unsubscribe(self) -> None:
        assert self._pubsub is not None
        self._pubsub.sunsubscribe(self._topic)  # type: ignore[attr-defined]

    def _get_message(self) -> dict | None:
        assert self._pubsub is not None
        return self._pubsub.get_sharded_message(ignore_subscribe_messages=True, timeout=0.1)  # type: ignore[attr-defined]

    def _get_message_type(self) -> str:
        return "smessage"
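Usage mirrors the regular channel; only the underlying commands change, and the server must be Redis 7.0+. A hedged sketch, reusing a `Redis` client constructed as in the previous example:

sharded_topic = ShardedRedisBroadcastChannel(client).topic("app:42:events")
with sharded_topic.subscribe() as sub:
    sharded_topic.publish(b"ping")  # SPUBLISH routes by the topic's hash slot
    payload = sub.receive(timeout=1.0)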
14 dify/api/libs/collection_utils.py Normal file
@@ -0,0 +1,14 @@
def convert_to_lower_and_upper_set(inputs: list[str] | set[str]) -> set[str]:
    """
    Convert a list or set of strings to a set containing both lower and upper case versions of each string.

    Args:
        inputs (list[str] | set[str]): A list or set of strings to be converted.

    Returns:
        set[str]: A set containing both lower and upper case versions of each string.
    """
    if not inputs:
        return set()
    else:
        return {case for s in inputs if s for case in (s.lower(), s.upper())}
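Illustrative behavior of the helper; note that mixed-case inputs contribute their fully-lowered and fully-uppered forms, not the original casing:

assert convert_to_lower_and_upper_set(["Get", "post"]) == {"get", "GET", "post", "POST"}
assert convert_to_lower_and_upper_set([]) == set()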
32 dify/api/libs/custom_inputs.py Normal file
@@ -0,0 +1,32 @@
"""Custom input types for Flask-RESTX request parsing."""

import re


def time_duration(value: str) -> str:
    """
    Validate and return a time duration string.

    Accepts formats: <number>d (days), <number>h (hours), <number>m (minutes), <number>s (seconds)
    Examples: 7d, 4h, 30m, 30s

    Args:
        value: The time duration string

    Returns:
        The validated time duration string

    Raises:
        ValueError: If the format is invalid
    """
    if not value:
        raise ValueError("Time duration cannot be empty")

    pattern = r"^(\d+)([dhms])$"
    if not re.match(pattern, value.lower()):
        raise ValueError(
            "Invalid time duration format. Use: <number>d (days), <number>h (hours), "
            "<number>m (minutes), or <number>s (seconds). Examples: 7d, 4h, 30m, 30s"
        )

    return value.lower()
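Illustrative calls; the validator normalizes to lower case and rejects anything outside the `<number><unit>` shape:

assert time_duration("7D") == "7d"  # unit letters are case-insensitive

try:
    time_duration("7 days")
except ValueError as exc:
    print(exc)  # "Invalid time duration format. ..."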
83 dify/api/libs/datetime_utils.py Normal file
@@ -0,0 +1,83 @@
import abc
import datetime
from typing import Protocol

import pytz


class _NowFunction(Protocol):
    @abc.abstractmethod
    def __call__(self, tz: datetime.timezone | None) -> datetime.datetime:
        pass


# _now_func is a callable with the _NowFunction signature.
# Its sole purpose is to abstract time retrieval, enabling
# developers to mock this behavior in tests and time-dependent scenarios.
_now_func: _NowFunction = datetime.datetime.now


def naive_utc_now() -> datetime.datetime:
    """Return a naive datetime object (without timezone information)
    representing the current UTC time.
    """
    return _now_func(datetime.UTC).replace(tzinfo=None)


def ensure_naive_utc(dt: datetime.datetime) -> datetime.datetime:
    """Return the datetime as naive UTC (tzinfo=None).

    If the input is timezone-aware, convert to UTC and drop the tzinfo.
    Assumes naive datetimes are already expressed in UTC.
    """
    if dt.tzinfo is None:
        return dt
    return dt.astimezone(datetime.UTC).replace(tzinfo=None)


def parse_time_range(
    start: str | None, end: str | None, tzname: str
) -> tuple[datetime.datetime | None, datetime.datetime | None]:
    """
    Parse time range strings and convert to UTC datetime objects.
    Handles DST ambiguity and non-existent times gracefully.

    Args:
        start: Start time string (YYYY-MM-DD HH:MM)
        end: End time string (YYYY-MM-DD HH:MM)
        tzname: Timezone name

    Returns:
        tuple: (start_datetime_utc, end_datetime_utc)

    Raises:
        ValueError: When the time range is invalid or start > end
    """
    tz = pytz.timezone(tzname)
    utc = pytz.utc

    def _parse(time_str: str | None, label: str) -> datetime.datetime | None:
        if not time_str:
            return None

        try:
            dt = datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M").replace(second=0)
        except ValueError as e:
            raise ValueError(f"Invalid {label} time format: {e}")

        try:
            return tz.localize(dt, is_dst=None).astimezone(utc)
        except pytz.AmbiguousTimeError:
            return tz.localize(dt, is_dst=False).astimezone(utc)
        except pytz.NonExistentTimeError:
            dt += datetime.timedelta(hours=1)
            return tz.localize(dt, is_dst=None).astimezone(utc)

    start_dt = _parse(start, "start")
    end_dt = _parse(end, "end")

    # Range validation
    if start_dt and end_dt and start_dt > end_dt:
        raise ValueError("start must be earlier than or equal to end")

    return start_dt, end_dt
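A hedged example of the DST handling: in America/New_York, 2024-11-03 01:30 occurs twice (fall-back), and `parse_time_range` resolves the ambiguity with `is_dst=False`, i.e. the later, standard-time instant:

start_utc, end_utc = parse_time_range(
    "2024-11-03 01:30", "2024-11-03 03:00", "America/New_York"
)
print(start_utc)  # 2024-11-03 06:30:00+00:00 (EST interpretation, UTC-5)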
604 dify/api/libs/email_i18n.py Normal file
@@ -0,0 +1,604 @@
"""
Email Internationalization Module

This module provides a centralized, elegant way to handle email internationalization
in Dify. It follows Domain-Driven Design principles with proper type hints and
eliminates the need for repetitive language switching logic.
"""

from dataclasses import dataclass
from enum import StrEnum, auto
from typing import Any, Protocol

from flask import render_template
from pydantic import BaseModel, Field

from extensions.ext_mail import mail
from services.feature_service import BrandingModel, FeatureService


class EmailType(StrEnum):
    """Enumeration of supported email types."""

    RESET_PASSWORD = auto()
    RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST = auto()
    INVITE_MEMBER = auto()
    EMAIL_CODE_LOGIN = auto()
    CHANGE_EMAIL_OLD = auto()
    CHANGE_EMAIL_NEW = auto()
    CHANGE_EMAIL_COMPLETED = auto()
    OWNER_TRANSFER_CONFIRM = auto()
    OWNER_TRANSFER_OLD_NOTIFY = auto()
    OWNER_TRANSFER_NEW_NOTIFY = auto()
    ACCOUNT_DELETION_SUCCESS = auto()
    ACCOUNT_DELETION_VERIFICATION = auto()
    ENTERPRISE_CUSTOM = auto()
    QUEUE_MONITOR_ALERT = auto()
    DOCUMENT_CLEAN_NOTIFY = auto()
    EMAIL_REGISTER = auto()
    EMAIL_REGISTER_WHEN_ACCOUNT_EXIST = auto()
    RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER = auto()
    TRIGGER_EVENTS_LIMIT_SANDBOX = auto()
    TRIGGER_EVENTS_LIMIT_PROFESSIONAL = auto()
    TRIGGER_EVENTS_USAGE_WARNING_SANDBOX = auto()
    TRIGGER_EVENTS_USAGE_WARNING_PROFESSIONAL = auto()
    API_RATE_LIMIT_LIMIT_SANDBOX = auto()
    API_RATE_LIMIT_WARNING_SANDBOX = auto()


class EmailLanguage(StrEnum):
    """Supported email languages with fallback handling."""

    EN_US = "en-US"
    ZH_HANS = "zh-Hans"

    @classmethod
    def from_language_code(cls, language_code: str) -> "EmailLanguage":
        """Convert a language code to EmailLanguage with fallback to English."""
        if language_code == "zh-Hans":
            return cls.ZH_HANS
        return cls.EN_US


@dataclass(frozen=True)
class EmailTemplate:
    """Immutable value object representing an email template configuration."""

    subject: str
    template_path: str
    branded_template_path: str


@dataclass(frozen=True)
class EmailContent:
    """Immutable value object containing rendered email content."""

    subject: str
    html_content: str
    template_context: dict[str, Any]


class EmailI18nConfig(BaseModel):
    """Configuration for email internationalization."""

    model_config = {"frozen": True, "extra": "forbid"}

    templates: dict[EmailType, dict[EmailLanguage, EmailTemplate]] = Field(
        default_factory=dict, description="Mapping of email types to language-specific templates"
    )

    def get_template(self, email_type: EmailType, language: EmailLanguage) -> EmailTemplate:
        """Get template configuration for specific email type and language."""
        type_templates = self.templates.get(email_type)
        if not type_templates:
            raise ValueError(f"No templates configured for email type: {email_type}")

        template = type_templates.get(language)
        if not template:
            # Fallback to English if specific language not found
            template = type_templates.get(EmailLanguage.EN_US)
            if not template:
                raise ValueError(f"No template found for {email_type} in {language} or English")

        return template


class EmailRenderer(Protocol):
    """Protocol for email template renderers."""

    def render_template(self, template_path: str, **context: Any) -> str:
        """Render email template with given context."""
        ...


class FlaskEmailRenderer:
    """Flask-based email template renderer."""

    def render_template(self, template_path: str, **context: Any) -> str:
        """Render email template using Flask's render_template."""
        return render_template(template_path, **context)


class BrandingService(Protocol):
    """Protocol for branding service abstraction."""

    def get_branding_config(self) -> BrandingModel:
        """Get current branding configuration."""
        ...


class FeatureBrandingService:
    """Feature service based branding implementation."""

    def get_branding_config(self) -> BrandingModel:
        """Get branding configuration from feature service."""
        return FeatureService.get_system_features().branding


class EmailSender(Protocol):
    """Protocol for email sending abstraction."""

    def send_email(self, to: str, subject: str, html_content: str):
        """Send email with given parameters."""
        ...


class FlaskMailSender:
    """Flask-Mail based email sender."""

    def send_email(self, to: str, subject: str, html_content: str):
        """Send email using Flask-Mail."""
        if mail.is_inited():
            mail.send(to=to, subject=subject, html=html_content)


class EmailI18nService:
    """
    Main service for internationalized email handling.

    This service provides a clean API for sending internationalized emails
    with proper branding support and template management.
    """

    def __init__(
        self,
        config: EmailI18nConfig,
        renderer: EmailRenderer,
        branding_service: BrandingService,
        sender: EmailSender,
    ):
        self._config = config
        self._renderer = renderer
        self._branding_service = branding_service
        self._sender = sender

    def send_email(
        self,
        email_type: EmailType,
        language_code: str,
        to: str,
        template_context: dict[str, Any] | None = None,
    ):
        """
        Send internationalized email with branding support.

        Args:
            email_type: Type of email to send
            language_code: Target language code
            to: Recipient email address
            template_context: Additional context for template rendering
        """
        if template_context is None:
            template_context = {}

        language = EmailLanguage.from_language_code(language_code)
        email_content = self._render_email_content(email_type, language, template_context)

        self._sender.send_email(to=to, subject=email_content.subject, html_content=email_content.html_content)

    def send_change_email(
        self,
        language_code: str,
        to: str,
        code: str,
        phase: str,
    ):
        """
        Send change email notification with phase-specific handling.

        Args:
            language_code: Target language code
            to: Recipient email address
            code: Verification code
            phase: Either 'old_email' or 'new_email'
        """
        if phase == "old_email":
            email_type = EmailType.CHANGE_EMAIL_OLD
        elif phase == "new_email":
            email_type = EmailType.CHANGE_EMAIL_NEW
        else:
            raise ValueError(f"Invalid phase: {phase}. Must be 'old_email' or 'new_email'")

        self.send_email(
            email_type=email_type,
            language_code=language_code,
            to=to,
            template_context={
                "to": to,
                "code": code,
            },
        )

    def send_raw_email(
        self,
        to: str | list[str],
        subject: str,
        html_content: str,
    ):
        """
        Send a raw email directly without template processing.

        This method is provided for backward compatibility with legacy email
        sending that uses pre-rendered HTML content (e.g., enterprise emails
        with custom templates).

        Args:
            to: Recipient email address(es)
            subject: Email subject
            html_content: Pre-rendered HTML content
        """
        if isinstance(to, list):
            for recipient in to:
                self._sender.send_email(to=recipient, subject=subject, html_content=html_content)
        else:
            self._sender.send_email(to=to, subject=subject, html_content=html_content)

    def _render_email_content(
        self,
        email_type: EmailType,
        language: EmailLanguage,
        template_context: dict[str, Any],
    ) -> EmailContent:
        """Render email content with branding and internationalization."""
        template_config = self._config.get_template(email_type, language)
        branding = self._branding_service.get_branding_config()

        # Determine template path based on branding
        template_path = template_config.branded_template_path if branding.enabled else template_config.template_path

        # Prepare template context with branding information
        full_context = {
            **template_context,
            "branding_enabled": branding.enabled,
            "application_title": branding.application_title if branding.enabled else "Dify",
        }

        # Render template
        html_content = self._renderer.render_template(template_path, **full_context)

        # Apply templating to subject with all context variables
        subject = template_config.subject
        try:
            subject = subject.format(**full_context)
        except KeyError:
            # If template variables are missing, fall back to basic formatting
            if branding.enabled and "{application_title}" in subject:
                subject = subject.format(application_title=branding.application_title)

        return EmailContent(
            subject=subject,
            html_content=html_content,
            template_context=full_context,
        )


def create_default_email_config() -> EmailI18nConfig:
    """Create default email i18n configuration with all supported templates."""
    templates: dict[EmailType, dict[EmailLanguage, EmailTemplate]] = {
        EmailType.RESET_PASSWORD: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="Set Your {application_title} Password",
                template_path="reset_password_mail_template_en-US.html",
                branded_template_path="without-brand/reset_password_mail_template_en-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="设置您的 {application_title} 密码",
                template_path="reset_password_mail_template_zh-CN.html",
                branded_template_path="without-brand/reset_password_mail_template_zh-CN.html",
            ),
        },
        EmailType.INVITE_MEMBER: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="Join {application_title} Workspace Now",
                template_path="invite_member_mail_template_en-US.html",
                branded_template_path="without-brand/invite_member_mail_template_en-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="立即加入 {application_title} 工作空间",
                template_path="invite_member_mail_template_zh-CN.html",
                branded_template_path="without-brand/invite_member_mail_template_zh-CN.html",
            ),
        },
        EmailType.EMAIL_CODE_LOGIN: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="{application_title} Login Code",
                template_path="email_code_login_mail_template_en-US.html",
                branded_template_path="without-brand/email_code_login_mail_template_en-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="{application_title} 登录验证码",
                template_path="email_code_login_mail_template_zh-CN.html",
                branded_template_path="without-brand/email_code_login_mail_template_zh-CN.html",
            ),
        },
        EmailType.CHANGE_EMAIL_OLD: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="Check your current email",
                template_path="change_mail_confirm_old_template_en-US.html",
                branded_template_path="without-brand/change_mail_confirm_old_template_en-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="检测您现在的邮箱",
                template_path="change_mail_confirm_old_template_zh-CN.html",
                branded_template_path="without-brand/change_mail_confirm_old_template_zh-CN.html",
            ),
        },
        EmailType.CHANGE_EMAIL_NEW: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="Confirm your new email address",
                template_path="change_mail_confirm_new_template_en-US.html",
                branded_template_path="without-brand/change_mail_confirm_new_template_en-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="确认您的邮箱地址变更",
                template_path="change_mail_confirm_new_template_zh-CN.html",
                branded_template_path="without-brand/change_mail_confirm_new_template_zh-CN.html",
            ),
        },
        EmailType.CHANGE_EMAIL_COMPLETED: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="Your login email has been changed",
                template_path="change_mail_completed_template_en-US.html",
                branded_template_path="without-brand/change_mail_completed_template_en-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="您的登录邮箱已更改",
                template_path="change_mail_completed_template_zh-CN.html",
                branded_template_path="without-brand/change_mail_completed_template_zh-CN.html",
            ),
        },
        EmailType.OWNER_TRANSFER_CONFIRM: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="Verify Your Request to Transfer Workspace Ownership",
                template_path="transfer_workspace_owner_confirm_template_en-US.html",
                branded_template_path="without-brand/transfer_workspace_owner_confirm_template_en-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="验证您转移工作空间所有权的请求",
                template_path="transfer_workspace_owner_confirm_template_zh-CN.html",
                branded_template_path="without-brand/transfer_workspace_owner_confirm_template_zh-CN.html",
            ),
        },
        EmailType.OWNER_TRANSFER_OLD_NOTIFY: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="Workspace ownership has been transferred",
                template_path="transfer_workspace_old_owner_notify_template_en-US.html",
                branded_template_path="without-brand/transfer_workspace_old_owner_notify_template_en-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="工作区所有权已转移",
                template_path="transfer_workspace_old_owner_notify_template_zh-CN.html",
                branded_template_path="without-brand/transfer_workspace_old_owner_notify_template_zh-CN.html",
            ),
        },
        EmailType.OWNER_TRANSFER_NEW_NOTIFY: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="You are now the owner of {WorkspaceName}",
                template_path="transfer_workspace_new_owner_notify_template_en-US.html",
                branded_template_path="without-brand/transfer_workspace_new_owner_notify_template_en-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="您现在是 {WorkspaceName} 的所有者",
                template_path="transfer_workspace_new_owner_notify_template_zh-CN.html",
                branded_template_path="without-brand/transfer_workspace_new_owner_notify_template_zh-CN.html",
            ),
        },
        EmailType.ACCOUNT_DELETION_SUCCESS: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="Your Dify.AI Account Has Been Successfully Deleted",
                template_path="delete_account_success_template_en-US.html",
                branded_template_path="delete_account_success_template_en-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="您的 Dify.AI 账户已成功删除",
                template_path="delete_account_success_template_zh-CN.html",
                branded_template_path="delete_account_success_template_zh-CN.html",
            ),
        },
        EmailType.ACCOUNT_DELETION_VERIFICATION: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="Dify.AI Account Deletion and Verification",
                template_path="delete_account_code_email_template_en-US.html",
                branded_template_path="delete_account_code_email_template_en-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="Dify.AI 账户删除和验证",
                template_path="delete_account_code_email_template_zh-CN.html",
                branded_template_path="delete_account_code_email_template_zh-CN.html",
            ),
        },
        EmailType.QUEUE_MONITOR_ALERT: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="Alert: Dataset Queue pending tasks exceeded the limit",
                template_path="queue_monitor_alert_email_template_en-US.html",
                branded_template_path="queue_monitor_alert_email_template_en-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="警报:数据集队列待处理任务超过限制",
                template_path="queue_monitor_alert_email_template_zh-CN.html",
                branded_template_path="queue_monitor_alert_email_template_zh-CN.html",
            ),
        },
        EmailType.DOCUMENT_CLEAN_NOTIFY: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="Dify Knowledge base auto disable notification",
                template_path="clean_document_job_mail_template-US.html",
                branded_template_path="clean_document_job_mail_template-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="Dify 知识库自动禁用通知",
                template_path="clean_document_job_mail_template_zh-CN.html",
                branded_template_path="clean_document_job_mail_template_zh-CN.html",
            ),
        },
        EmailType.TRIGGER_EVENTS_LIMIT_SANDBOX: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="You’ve reached your Sandbox Trigger Events limit",
                template_path="trigger_events_limit_template_en-US.html",
                branded_template_path="without-brand/trigger_events_limit_template_en-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="您的 Sandbox 触发事件额度已用尽",
                template_path="trigger_events_limit_template_zh-CN.html",
                branded_template_path="without-brand/trigger_events_limit_template_zh-CN.html",
            ),
        },
        EmailType.TRIGGER_EVENTS_LIMIT_PROFESSIONAL: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="You’ve reached your monthly Trigger Events limit",
                template_path="trigger_events_limit_template_en-US.html",
                branded_template_path="without-brand/trigger_events_limit_template_en-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="您的月度触发事件额度已用尽",
                template_path="trigger_events_limit_template_zh-CN.html",
                branded_template_path="without-brand/trigger_events_limit_template_zh-CN.html",
            ),
        },
        EmailType.TRIGGER_EVENTS_USAGE_WARNING_SANDBOX: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="You’re nearing your Sandbox Trigger Events limit",
                template_path="trigger_events_usage_warning_template_en-US.html",
                branded_template_path="without-brand/trigger_events_usage_warning_template_en-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="您的 Sandbox 触发事件额度接近上限",
                template_path="trigger_events_usage_warning_template_zh-CN.html",
                branded_template_path="without-brand/trigger_events_usage_warning_template_zh-CN.html",
            ),
        },
        EmailType.TRIGGER_EVENTS_USAGE_WARNING_PROFESSIONAL: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="You’re nearing your Monthly Trigger Events limit",
                template_path="trigger_events_usage_warning_template_en-US.html",
                branded_template_path="without-brand/trigger_events_usage_warning_template_en-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="您的月度触发事件额度接近上限",
                template_path="trigger_events_usage_warning_template_zh-CN.html",
                branded_template_path="without-brand/trigger_events_usage_warning_template_zh-CN.html",
            ),
        },
        EmailType.API_RATE_LIMIT_LIMIT_SANDBOX: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="You’ve reached your API Rate Limit",
                template_path="api_rate_limit_limit_template_en-US.html",
                branded_template_path="without-brand/api_rate_limit_limit_template_en-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="您的 API 速率额度已用尽",
                template_path="api_rate_limit_limit_template_zh-CN.html",
                branded_template_path="without-brand/api_rate_limit_limit_template_zh-CN.html",
            ),
        },
        EmailType.API_RATE_LIMIT_WARNING_SANDBOX: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="You’re nearing your API Rate Limit",
                template_path="api_rate_limit_warning_template_en-US.html",
                branded_template_path="without-brand/api_rate_limit_warning_template_en-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="您的 API 速率额度接近上限",
                template_path="api_rate_limit_warning_template_zh-CN.html",
                branded_template_path="without-brand/api_rate_limit_warning_template_zh-CN.html",
            ),
        },
        EmailType.EMAIL_REGISTER: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="Register Your {application_title} Account",
                template_path="register_email_template_en-US.html",
                branded_template_path="without-brand/register_email_template_en-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="注册您的 {application_title} 账户",
                template_path="register_email_template_zh-CN.html",
                branded_template_path="without-brand/register_email_template_zh-CN.html",
            ),
        },
        EmailType.EMAIL_REGISTER_WHEN_ACCOUNT_EXIST: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="Register Your {application_title} Account",
                template_path="register_email_when_account_exist_template_en-US.html",
                branded_template_path="without-brand/register_email_when_account_exist_template_en-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="注册您的 {application_title} 账户",
                template_path="register_email_when_account_exist_template_zh-CN.html",
                branded_template_path="without-brand/register_email_when_account_exist_template_zh-CN.html",
            ),
        },
        EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="Reset Your {application_title} Password",
                template_path="reset_password_mail_when_account_not_exist_template_en-US.html",
                branded_template_path="without-brand/reset_password_mail_when_account_not_exist_template_en-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="重置您的 {application_title} 密码",
                template_path="reset_password_mail_when_account_not_exist_template_zh-CN.html",
                branded_template_path="without-brand/reset_password_mail_when_account_not_exist_template_zh-CN.html",
            ),
        },
        EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER: {
            EmailLanguage.EN_US: EmailTemplate(
                subject="Reset Your {application_title} Password",
                template_path="reset_password_mail_when_account_not_exist_no_register_template_en-US.html",
                branded_template_path="without-brand/reset_password_mail_when_account_not_exist_no_register_template_en-US.html",
            ),
            EmailLanguage.ZH_HANS: EmailTemplate(
                subject="重置您的 {application_title} 密码",
                template_path="reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html",
                branded_template_path="without-brand/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html",
            ),
        },
    }

    return EmailI18nConfig(templates=templates)


# Singleton instance for application-wide use
def get_default_email_i18n_service() -> EmailI18nService:
    """Get configured email i18n service with default dependencies."""
    config = create_default_email_config()
    renderer = FlaskEmailRenderer()
    branding_service = FeatureBrandingService()
    sender = FlaskMailSender()

    return EmailI18nService(
        config=config,
        renderer=renderer,
        branding_service=branding_service,
        sender=sender,
    )


# Global instance
_email_i18n_service: EmailI18nService | None = None


def get_email_i18n_service() -> EmailI18nService:
    """Get global email i18n service instance."""
    global _email_i18n_service
    if _email_i18n_service is None:
        _email_i18n_service = get_default_email_i18n_service()
    return _email_i18n_service
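Because the dependencies are Protocols, the service can be exercised without Flask or SMTP by wiring in test doubles. A sketch under that assumption; the fakes below are hypothetical, not part of the module, and the minimal branding stand-in assumes only the `enabled` and `application_title` fields of `BrandingModel`:

from typing import Any


class _CapturingSender:
    def __init__(self) -> None:
        self.sent: list[tuple[str, str, str]] = []

    def send_email(self, to: str, subject: str, html_content: str):
        self.sent.append((to, subject, html_content))


class _FakeRenderer:
    def render_template(self, template_path: str, **context: Any) -> str:
        return f"<html>{template_path}</html>"  # stand-in for Flask rendering


class _FakeBranding:
    def get_branding_config(self):
        class _B:  # minimal stand-in for BrandingModel (assumed fields)
            enabled = False
            application_title = "Dify"

        return _B()


service = EmailI18nService(
    config=create_default_email_config(),
    renderer=_FakeRenderer(),
    branding_service=_FakeBranding(),
    sender=_CapturingSender(),
)
service.send_email(EmailType.EMAIL_CODE_LOGIN, "zh-Hans", "user@example.com", {"code": "123456"})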
15 dify/api/libs/exception.py Normal file
@@ -0,0 +1,15 @@
from werkzeug.exceptions import HTTPException


class BaseHTTPException(HTTPException):
    error_code: str = "unknown"
    data: dict | None = None

    def __init__(self, description=None, response=None):
        super().__init__(description, response)

        self.data = {
            "code": self.error_code,
            "message": self.description,
            "status": self.code,
        }
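A hypothetical subclass showing the intended pattern: concrete errors declare an `error_code`, an HTTP status `code`, and a default `description`, and the payload dict is assembled automatically:

class QuotaExceededError(BaseHTTPException):
    error_code = "quota_exceeded"
    code = 429
    description = "The quota for this workspace has been exhausted."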
142 dify/api/libs/external_api.py Normal file
@@ -0,0 +1,142 @@
import re
import sys
from collections.abc import Mapping
from typing import Any

from flask import Blueprint, Flask, current_app, got_request_exception
from flask_restx import Api
from werkzeug.exceptions import HTTPException
from werkzeug.http import HTTP_STATUS_CODES

from configs import dify_config
from core.errors.error import AppInvokeQuotaExceededError
from libs.token import build_force_logout_cookie_headers


def http_status_message(code):
    return HTTP_STATUS_CODES.get(code, "")


def register_external_error_handlers(api: Api):
    @api.errorhandler(HTTPException)
    def handle_http_exception(e: HTTPException):
        got_request_exception.send(current_app, exception=e)

        # If Werkzeug already prepared a Response, just use it.
        if e.response is not None:
            return e.response

        status_code = getattr(e, "code", 500) or 500

        # Build a safe, dict-like payload
        default_data = {
            "code": re.sub(r"(?<!^)(?=[A-Z])", "_", type(e).__name__).lower(),
            "message": getattr(e, "description", http_status_message(status_code)),
            "status": status_code,
        }
        if default_data["message"] == "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)":
            default_data["message"] = "Invalid JSON payload received or JSON payload is empty."

        # Use headers on the exception if present; otherwise none.
        headers = {}
        exc_headers = getattr(e, "headers", None)
        if exc_headers:
            headers.update(exc_headers)

        # Payload per status
        if status_code == 406 and api.default_mediatype is None:
            data = {"code": "not_acceptable", "message": default_data["message"], "status": status_code}
            return data, status_code, headers
        elif status_code == 400:
            msg = default_data["message"]
            if isinstance(msg, Mapping) and msg:
                # Convert param errors like {"field": "reason"} into a friendly shape
                param_key, param_value = next(iter(msg.items()))
                data = {
                    "code": "invalid_param",
                    "message": str(param_value),
                    "params": param_key,
                    "status": status_code,
                }
            else:
                data = {**default_data}
                data.setdefault("code", "unknown")
            return data, status_code, headers
        else:
            data = {**default_data}
            data.setdefault("code", "unknown")
            # If you need WWW-Authenticate for 401, add it to headers
            if status_code == 401:
                headers["WWW-Authenticate"] = 'Bearer realm="api"'
                # Check if this is a forced logout error - clear cookies
                error_code = getattr(e, "error_code", None)
                if error_code == "unauthorized_and_force_logout":
                    # Add Set-Cookie headers to clear auth cookies
                    headers["Set-Cookie"] = build_force_logout_cookie_headers()
            return data, status_code, headers

    _ = handle_http_exception

    @api.errorhandler(ValueError)
    def handle_value_error(e: ValueError):
        got_request_exception.send(current_app, exception=e)
        status_code = 400
        data = {"code": "invalid_param", "message": str(e), "status": status_code}
        return data, status_code

    _ = handle_value_error

    @api.errorhandler(AppInvokeQuotaExceededError)
    def handle_quota_exceeded(e: AppInvokeQuotaExceededError):
        got_request_exception.send(current_app, exception=e)
        status_code = 429
        data = {"code": "too_many_requests", "message": str(e), "status": status_code}
        return data, status_code

    _ = handle_quota_exceeded

    @api.errorhandler(Exception)
    def handle_general_exception(e: Exception):
        got_request_exception.send(current_app, exception=e)

        status_code = 500
        data: dict[str, Any] = getattr(e, "data", {"message": http_status_message(status_code)})

        # 🔒 Normalize non-mapping data (e.g., if someone set e.data = Response)
        if not isinstance(data, dict):
            data = {"message": str(e)}

        data.setdefault("code", "unknown")
        data.setdefault("status", status_code)

        # Log stack
        exc_info: Any = sys.exc_info()
        if exc_info[1] is None:
            exc_info = (None, None, None)
        current_app.log_exception(exc_info)

        return data, status_code

    _ = handle_general_exception


class ExternalApi(Api):
    _authorizations = {
        "Bearer": {
            "type": "apiKey",
            "in": "header",
            "name": "Authorization",
            "description": "Type: Bearer {your-api-key}",
        }
    }

    def __init__(self, app: Blueprint | Flask, *args, **kwargs):
        kwargs.setdefault("authorizations", self._authorizations)
        kwargs.setdefault("security", "Bearer")
        kwargs["add_specs"] = dify_config.SWAGGER_UI_ENABLED
        kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False

        # Construct with app=None first, then call init_app separately, so the configs in kwargs take effect.
        super().__init__(app=None, *args, **kwargs)
        self.init_app(app, **kwargs)
        register_external_error_handlers(self)
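A hedged wiring sketch (blueprint name, URL prefix, and API metadata are illustrative); the constructor registers the error handlers itself:

from flask import Blueprint, Flask

app = Flask(__name__)
bp = Blueprint("service_api", __name__, url_prefix="/v1")
api = ExternalApi(bp, version="1.0", title="Service API")
app.register_blueprint(bp)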
30 dify/api/libs/file_utils.py Normal file
@@ -0,0 +1,30 @@
from pathlib import Path


def search_file_upwards(
    base_dir_path: Path,
    target_file_name: str,
    max_search_parent_depth: int,
) -> Path:
    """
    Find a target file in the current directory or its parent directories up to a specified depth.

    :param base_dir_path: Starting directory path to search from.
    :param target_file_name: Name of the file to search for.
    :param max_search_parent_depth: Maximum number of parent directories to search upwards.
    :return: Path of the file if found; otherwise a ValueError is raised.
    """
    current_path = base_dir_path.resolve()
    for _ in range(max_search_parent_depth):
        candidate_path = current_path / target_file_name
        if candidate_path.is_file():
            return candidate_path
        parent_path = current_path.parent
        if parent_path == current_path:  # reached the root directory
            break
        else:
            current_path = parent_path

    raise ValueError(
        f"File '{target_file_name}' not found in the directory '{base_dir_path.resolve()}' or its parent directories"
        f" in depth of {max_search_parent_depth}."
    )
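Illustrative use, e.g. locating a project-level file from a nested module (the file name and depth are arbitrary):

config_path = search_file_upwards(Path(__file__).parent, "pyproject.toml", max_search_parent_depth=5)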
66 dify/api/libs/flask_utils.py Normal file
@@ -0,0 +1,66 @@
import contextvars
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TypeVar

from flask import Flask, g

T = TypeVar("T")


@contextmanager
def preserve_flask_contexts(
    flask_app: Flask,
    context_vars: contextvars.Context,
) -> Iterator[None]:
    """
    A context manager that handles:
    1. flask-login's UserProxy copy
    2. ContextVars copy
    3. flask_app.app_context()

    This context manager ensures that the Flask application context is properly set up,
    the current user is preserved across context boundaries, and any provided context variables
    are set within the new context.

    Note:
        This manager makes current_user usable across threads and app contexts,
        but that is not the recommended usage; prefer passing the user explicitly as a parameter.

    Args:
        flask_app: The Flask application instance
        context_vars: contextvars.Context object containing context variables to be set in the new context

    Yields:
        None

    Example:
        ```python
        with preserve_flask_contexts(flask_app, context_vars=context_vars):
            # Code that needs Flask app context and context variables
            # Current user will be preserved if available
        ```
    """
    # Set context variables if provided
    if context_vars:
        for var, val in context_vars.items():
            var.set(val)

    # Save current user before entering new app context
    saved_user = None
    # Check for user in g (works in both request context and app context)
    if hasattr(g, "_login_user"):
        saved_user = g._login_user

    # Enter Flask app context
    with flask_app.app_context():
        try:
            # Restore user in new app context if it was saved
            if saved_user is not None:
                g._login_user = saved_user

            # Yield control back to the caller
            yield
        finally:
            # Any cleanup can be added here if needed
            pass
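A hedged sketch of the intended call pattern: copy the caller's context variables, then re-enter the app context from a worker thread (the worker body is illustrative):

```python
import contextvars
import threading

from flask import Flask

app = Flask(__name__)


def _worker(ctx: contextvars.Context) -> None:
    # Inside the manager, the app context is active and the copied
    # context variables (and any saved login user) are visible.
    with preserve_flask_contexts(app, context_vars=ctx):
        ...


threading.Thread(target=_worker, args=(contextvars.copy_context(),)).start()
```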
241
dify/api/libs/gmpy2_pkcs10aep_cipher.py
Normal file
@@ -0,0 +1,241 @@
#
# Cipher/PKCS1_OAEP.py : PKCS#1 OAEP
#
# ===================================================================
# The contents of this file are dedicated to the public domain. To
# the extent that dedication to the public domain is not available,
# everyone is granted a worldwide, perpetual, royalty-free,
# non-exclusive license to exercise all rights associated with the
# contents of this file for any purpose whatsoever.
# No rights are reserved.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ===================================================================

from hashlib import sha1

import Crypto.Hash.SHA1
import Crypto.Util.number
import gmpy2
from Crypto import Random
from Crypto.Signature.pss import MGF1
from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes
from Crypto.Util.py3compat import bord
from Crypto.Util.strxor import strxor


class PKCS1OAepCipher:
    """Cipher object for PKCS#1 OAEP.
    Do not create directly: use :func:`new` instead."""

    def __init__(self, key, hashAlgo, mgfunc, label, randfunc):
        """Initialize this PKCS#1 OAEP cipher object.

        :Parameters:
         key : an RSA key object
           If a private half is given, both encryption and decryption are possible.
           If a public half is given, only encryption is possible.
         hashAlgo : hash object
           The hash function to use. This can be a module under `Crypto.Hash`
           or an existing hash object created from any of such modules. If not specified,
           `Crypto.Hash.SHA1` is used.
         mgfunc : callable
           A mask generation function that accepts two parameters: a string to
           use as seed, and the length of the mask to generate, in bytes.
           If not specified, the standard MGF1 consistent with ``hashAlgo`` is used (a safe choice).
         label : bytes/bytearray/memoryview
           A label to apply to this particular encryption. If not specified,
           an empty string is used. Specifying a label does not improve
           security.
         randfunc : callable
           A function that returns random bytes.

        :attention: Modify the mask generation function only if you know what you are doing.
                    Sender and receiver must use the same one.
        """
        self._key = key

        if hashAlgo:
            self._hashObj = hashAlgo
        else:
            self._hashObj = Crypto.Hash.SHA1

        if mgfunc:
            self._mgf = mgfunc
        else:
            self._mgf = lambda x, y: MGF1(x, y, self._hashObj)

        self._label = bytes(label)
        self._randfunc = randfunc

    def can_encrypt(self):
        """Legacy function to check if you can call :meth:`encrypt`.

        .. deprecated:: 3.0"""
        return self._key.can_encrypt()

    def can_decrypt(self):
        """Legacy function to check if you can call :meth:`decrypt`.

        .. deprecated:: 3.0"""
        return self._key.can_decrypt()

    def encrypt(self, message):
        """Encrypt a message with PKCS#1 OAEP.

        :param message:
            The message to encrypt, also known as plaintext. It can be of
            variable length, but not longer than the RSA modulus (in bytes)
            minus 2, minus twice the hash output size.
            For instance, if you use RSA 2048 and SHA-256, the longest message
            you can encrypt is 190 bytes long.
        :type message: bytes/bytearray/memoryview

        :returns: The ciphertext, as large as the RSA modulus.
        :rtype: bytes

        :raises ValueError:
            if the message is too long.
        """

        # See 7.1.1 in RFC3447
        modBits = Crypto.Util.number.size(self._key.n)
        k = ceil_div(modBits, 8)  # Convert from bits to bytes
        hLen = self._hashObj.digest_size
        mLen = len(message)

        # Step 1b
        ps_len = k - mLen - 2 * hLen - 2
        if ps_len < 0:
            raise ValueError("Plaintext is too long.")
        # Step 2a (note: the label hash is fixed to SHA-1 here, independent of hashAlgo)
        lHash = sha1(self._label).digest()
        # Step 2b
        ps = b"\x00" * ps_len
        # Step 2c
        db = lHash + ps + b"\x01" + bytes(message)
        # Step 2d
        ros = self._randfunc(hLen)
        # Step 2e
        dbMask = self._mgf(ros, k - hLen - 1)
        # Step 2f
        maskedDB = strxor(db, dbMask)
        # Step 2g
        seedMask = self._mgf(maskedDB, hLen)
        # Step 2h
        maskedSeed = strxor(ros, seedMask)
        # Step 2i
        em = b"\x00" + maskedSeed + maskedDB
        # Step 3a (OS2IP)
        em_int = bytes_to_long(em)
        # Step 3b (RSAEP)
        c_int = gmpy2.powmod(em_int, self._key.e, self._key.n)
        # Step 3c (I2OSP)
        c = long_to_bytes(c_int, k)
        return c

    def decrypt(self, ciphertext):
        """Decrypt a message with PKCS#1 OAEP.

        :param ciphertext: The encrypted message.
        :type ciphertext: bytes/bytearray/memoryview

        :returns: The original message (plaintext).
        :rtype: bytes

        :raises ValueError:
            if the ciphertext has the wrong length, or if decryption
            fails the integrity check (in which case, the decryption
            key is probably wrong).
        :raises TypeError:
            if the RSA key has no private half (i.e. you are trying
            to decrypt using a public key).
        """
        # See 7.1.2 in RFC3447
        modBits = Crypto.Util.number.size(self._key.n)
        k = ceil_div(modBits, 8)  # Convert from bits to bytes
        hLen = self._hashObj.digest_size
        # Step 1b and 1c
        if len(ciphertext) != k or k < hLen + 2:
            raise ValueError("Ciphertext with incorrect length.")
        # Step 2a (O2SIP)
        ct_int = bytes_to_long(ciphertext)
        # Step 2b (RSADP)
        m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)
        # Complete step 2c (I2OSP)
        em = long_to_bytes(m_int, k)
        # Step 3a (note: the label hash is fixed to SHA-1 here, matching encrypt)
        lHash = sha1(self._label).digest()
        # Step 3b
        y = em[0]
        # y must be 0, but we MUST NOT check it here in order not to
        # allow attacks like Manger's (http://dl.acm.org/citation.cfm?id=704143)
        maskedSeed = em[1 : hLen + 1]
        maskedDB = em[hLen + 1 :]
        # Step 3c
        seedMask = self._mgf(maskedDB, hLen)
        # Step 3d
        seed = strxor(maskedSeed, seedMask)
        # Step 3e
        dbMask = self._mgf(seed, k - hLen - 1)
        # Step 3f
        db = strxor(maskedDB, dbMask)
        # Step 3g
        one_pos = hLen + db[hLen:].find(b"\x01")
        lHash1 = db[:hLen]
        invalid = bord(y) | int(one_pos < hLen)  # type: ignore[arg-type]
        hash_compare = strxor(lHash1, lHash)
        for x in hash_compare:
            invalid |= bord(x)  # type: ignore[arg-type]
        for x in db[hLen:one_pos]:
            invalid |= bord(x)  # type: ignore[arg-type]
        if invalid != 0:
            raise ValueError("Incorrect decryption.")
        # Step 4
        return db[one_pos + 1 :]


def new(key, hashAlgo=None, mgfunc=None, label=b"", randfunc=None):
    """Return a cipher object :class:`PKCS1OAepCipher`
    that can be used to perform PKCS#1 OAEP encryption or decryption.

    :param key:
      The key object to use to encrypt or decrypt the message.
      Decryption is only possible with a private RSA key.
    :type key: RSA key object

    :param hashAlgo:
      The hash function to use. This can be a module under `Crypto.Hash`
      or an existing hash object created from any of such modules.
      If not specified, `Crypto.Hash.SHA1` is used.
    :type hashAlgo: hash object

    :param mgfunc:
      A mask generation function that accepts two parameters: a string to
      use as seed, and the length of the mask to generate, in bytes.
      If not specified, the standard MGF1 consistent with ``hashAlgo`` is used (a safe choice).
    :type mgfunc: callable

    :param label:
      A label to apply to this particular encryption. If not specified,
      an empty string is used. Specifying a label does not improve
      security.
    :type label: bytes/bytearray/memoryview

    :param randfunc:
      A function that returns random bytes.
      The default is `Random.get_random_bytes`.
    :type randfunc: callable
    """

    if randfunc is None:
        randfunc = Random.get_random_bytes
    return PKCS1OAepCipher(key, hashAlgo, mgfunc, label, randfunc)
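A minimal round trip with the gmpy2-backed cipher, mirroring PyCryptodome's PKCS1_OAEP API (key size and payload are illustrative):

```python
from Crypto.PublicKey import RSA

key = RSA.generate(2048)
cipher = new(key)  # `new` as defined above; defaults to SHA-1 and MGF1

ciphertext = cipher.encrypt(b"secret payload")
assert cipher.decrypt(ciphertext) == b"secret payload"
```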
380
dify/api/libs/helper.py
Normal file
@@ -0,0 +1,380 @@
import json
import logging
import re
import secrets
import string
import struct
import subprocess
import time
import uuid
from collections.abc import Generator, Mapping
from datetime import datetime
from hashlib import sha256
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from zoneinfo import available_timezones

from flask import Response, stream_with_context
from flask_restx import fields
from pydantic import BaseModel

from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.file import helpers as file_helpers
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_redis import redis_client

if TYPE_CHECKING:
    from models import Account
    from models.model import EndUser

logger = logging.getLogger(__name__)


def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None:
    """
    Extract tenant_id from an Account or EndUser object.

    Args:
        user: Account or EndUser object

    Returns:
        tenant_id string if available, None otherwise

    Raises:
        ValueError: If user is neither Account nor EndUser
    """
    from models import Account
    from models.model import EndUser

    if isinstance(user, Account):
        return user.current_tenant_id
    elif isinstance(user, EndUser):
        return user.tenant_id
    else:
        raise ValueError(f"Invalid user type: {type(user)}. Expected Account or EndUser.")


def run(script):
    return subprocess.getstatusoutput("source /root/.bashrc && " + script)


class AppIconUrlField(fields.Raw):
    def output(self, key, obj, **kwargs):
        if obj is None:
            return None

        from models.model import App, IconType, Site

        if isinstance(obj, dict) and "app" in obj:
            obj = obj["app"]

        if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE:
            return file_helpers.get_signed_file_url(obj.icon)
        return None


class AvatarUrlField(fields.Raw):
    def output(self, key, obj, **kwargs):
        if obj is None:
            return None

        from models import Account

        if isinstance(obj, Account) and obj.avatar is not None:
            if obj.avatar.startswith(("http://", "https://")):
                return obj.avatar
            return file_helpers.get_signed_file_url(obj.avatar)
        return None


class TimestampField(fields.Raw):
    def format(self, value) -> int:
        return int(value.timestamp())


def email(email):
    # Define a regex pattern for email addresses
    pattern = r"^[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}$"
    # Check if the email matches the pattern
    if re.match(pattern, email) is not None:
        return email

    error = f"{email} is not a valid email."
    raise ValueError(error)


def uuid_value(value):
    if value == "":
        return str(value)

    try:
        uuid_obj = uuid.UUID(value)
        return str(uuid_obj)
    except ValueError:
        error = f"{value} is not a valid uuid."
        raise ValueError(error)


def alphanumeric(value: str):
    # check if the value contains only alphanumeric characters and underscores
    if re.match(r"^[a-zA-Z0-9_]+$", value):
        return value

    raise ValueError(f"{value} is not a valid alphanumeric value")


def timestamp_value(timestamp):
    try:
        int_timestamp = int(timestamp)
        if int_timestamp < 0:
            raise ValueError
        return int_timestamp
    except ValueError:
        error = f"{timestamp} is not a valid timestamp."
        raise ValueError(error)


class StrLen:
    """Restrict input strings to a maximum length (inclusive)."""

    def __init__(self, max_length, argument="argument"):
        self.max_length = max_length
        self.argument = argument

    def __call__(self, value):
        length = len(value)
        if length > self.max_length:
            error = "Invalid {arg}: {val}. {arg} cannot exceed length {length}".format(
                arg=self.argument, val=value, length=self.max_length
            )
            raise ValueError(error)

        return value


class DatetimeString:
    def __init__(self, format, argument="argument"):
        self.format = format
        self.argument = argument

    def __call__(self, value):
        try:
            datetime.strptime(value, self.format)
        except ValueError:
            error = "Invalid {arg}: {val}. {arg} must conform to the format {format}".format(
                arg=self.argument, val=value, format=self.format
            )
            raise ValueError(error)

        return value


def timezone(timezone_string):
    if timezone_string and timezone_string in available_timezones():
        return timezone_string

    error = f"{timezone_string} is not a valid timezone."
    raise ValueError(error)


def convert_datetime_to_date(field, target_timezone: str = ":tz"):
    if dify_config.DB_TYPE == "postgresql":
        return f"DATE(DATE_TRUNC('day', {field} AT TIME ZONE 'UTC' AT TIME ZONE {target_timezone}))"
    elif dify_config.DB_TYPE == "mysql":
        return f"DATE(CONVERT_TZ({field}, 'UTC', {target_timezone}))"
    else:
        raise NotImplementedError(f"Unsupported database type: {dify_config.DB_TYPE}")


def generate_string(n):
    letters_digits = string.ascii_letters + string.digits
    result = ""
    for _ in range(n):
        result += secrets.choice(letters_digits)

    return result


def extract_remote_ip(request) -> str:
    if request.headers.get("CF-Connecting-IP"):
        return cast(str, request.headers.get("CF-Connecting-IP"))
    elif request.headers.getlist("X-Forwarded-For"):
        return cast(str, request.headers.getlist("X-Forwarded-For")[0])
    else:
        return cast(str, request.remote_addr)


def generate_text_hash(text: str) -> str:
    hash_text = str(text) + "None"
    return sha256(hash_text.encode()).hexdigest()


def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
    if isinstance(response, dict):
        return Response(response=json.dumps(jsonable_encoder(response)), status=200, mimetype="application/json")
    else:

        def generate() -> Generator:
            yield from response

        return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")


def length_prefixed_response(magic_number: int, response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
    """
    Return a response with a length prefix.
    The magic number is a one-byte value that indicates the type of the response.

    For compatibility with the latest plugin daemon, see
    https://github.com/langgenius/dify-plugin-daemon/pull/341
    Avoid line-based responses: they lead to memory issues.

    The frame format is:
    | Field         | Size     | Description                     |
    |---------------|----------|---------------------------------|
    | Magic Number  | 1 byte   | Magic number identifier         |
    | Reserved      | 1 byte   | Reserved field                  |
    | Header Length | 2 bytes  | Header length (usually 0xa)     |
    | Data Length   | 4 bytes  | Length of the data              |
    | Reserved      | 6 bytes  | Reserved fields                 |
    | Data          | Variable | Actual data content             |

    All fields are little-endian.
    """

    def pack_response_with_length_prefix(response: bytes) -> bytes:
        header_length = 0xA
        data_length = len(response)
        # | Magic Number 1byte | Reserved 1byte | Header Length 2bytes | Data Length 4bytes | Reserved 6bytes | Data
        return struct.pack("<BBHI", magic_number, 0, header_length, data_length) + b"\x00" * 6 + response

    if isinstance(response, dict):
        return Response(
            response=pack_response_with_length_prefix(json.dumps(jsonable_encoder(response)).encode("utf-8")),
            status=200,
            mimetype="application/json",
        )
    elif isinstance(response, BaseModel):
        return Response(
            response=pack_response_with_length_prefix(response.model_dump_json().encode("utf-8")),
            status=200,
            mimetype="application/json",
        )

    def generate() -> Generator:
        for chunk in response:
            if isinstance(chunk, str):
                yield pack_response_with_length_prefix(chunk.encode("utf-8"))
            else:
                yield pack_response_with_length_prefix(chunk)

    return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
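For reference, a hedged sketch of the matching reader side, derived from `pack_response_with_length_prefix` above: the fixed `<BBHI` part is 8 bytes and is followed by 6 reserved bytes, so the payload starts at offset 14.

```python
import struct


def read_frame(buf: bytes) -> tuple[int, bytes, bytes]:
    """Parse one frame; returns (magic_number, data, remaining bytes).

    Assumes `buf` starts at a frame boundary and holds the full frame.
    """
    magic, _reserved, _header_len, data_len = struct.unpack_from("<BBHI", buf, 0)
    start = 8 + 6  # fixed header fields plus the 6 reserved bytes
    return magic, buf[start : start + data_len], buf[start + data_len :]
```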
class TokenManager:
    @classmethod
    def generate_token(
        cls,
        token_type: str,
        account: Optional["Account"] = None,
        email: str | None = None,
        additional_data: dict | None = None,
    ) -> str:
        if account is None and email is None:
            raise ValueError("Account or email must be provided")

        account_id = account.id if account else None
        account_email = account.email if account else email

        if account_id:
            old_token = cls._get_current_token_for_account(account_id, token_type)
            if old_token:
                if isinstance(old_token, bytes):
                    old_token = old_token.decode("utf-8")
                cls.revoke_token(old_token, token_type)

        token = str(uuid.uuid4())
        token_data = {"account_id": account_id, "email": account_email, "token_type": token_type}
        if additional_data:
            token_data.update(additional_data)

        expiry_minutes = dify_config.model_dump().get(f"{token_type.upper()}_TOKEN_EXPIRY_MINUTES")
        if expiry_minutes is None:
            raise ValueError(f"Expiry minutes for {token_type} token is not set")
        token_key = cls._get_token_key(token, token_type)
        expiry_seconds = int(expiry_minutes * 60)
        redis_client.setex(token_key, expiry_seconds, json.dumps(token_data))

        if account_id:
            cls._set_current_token_for_account(account_id, token, token_type, expiry_minutes)

        return token

    @classmethod
    def _get_token_key(cls, token: str, token_type: str) -> str:
        return f"{token_type}:token:{token}"

    @classmethod
    def revoke_token(cls, token: str, token_type: str):
        token_key = cls._get_token_key(token, token_type)
        redis_client.delete(token_key)

    @classmethod
    def get_token_data(cls, token: str, token_type: str) -> dict[str, Any] | None:
        key = cls._get_token_key(token, token_type)
        token_data_json = redis_client.get(key)
        if token_data_json is None:
            logger.warning("%s token %s not found with key %s", token_type, token, key)
            return None
        token_data: dict[str, Any] | None = json.loads(token_data_json)
        return token_data

    @classmethod
    def _get_current_token_for_account(cls, account_id: str, token_type: str) -> str | None:
        key = cls._get_account_token_key(account_id, token_type)
        current_token: str | None = redis_client.get(key)
        return current_token

    @classmethod
    def _set_current_token_for_account(
        cls, account_id: str, token: str, token_type: str, expiry_minutes: Union[int, float]
    ):
        key = cls._get_account_token_key(account_id, token_type)
        expiry_seconds = int(expiry_minutes * 60)
        redis_client.setex(key, expiry_seconds, token)

    @classmethod
    def _get_account_token_key(cls, account_id: str, token_type: str) -> str:
        return f"{token_type}:account:{account_id}"
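A usage sketch, assuming a `reset_password` token type whose expiry is configured via a `RESET_PASSWORD_TOKEN_EXPIRY_MINUTES` setting in dify_config:

```python
token = TokenManager.generate_token(token_type="reset_password", email="user@example.com")

data = TokenManager.get_token_data(token, "reset_password")
if data is not None:
    print(data["email"])
    TokenManager.revoke_token(token, "reset_password")  # enforce one-time use
```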
class RateLimiter:
    def __init__(self, prefix: str, max_attempts: int, time_window: int):
        self.prefix = prefix
        self.max_attempts = max_attempts
        self.time_window = time_window

    def _get_key(self, email: str) -> str:
        return f"{self.prefix}:{email}"

    def is_rate_limited(self, email: str) -> bool:
        key = self._get_key(email)
        current_time = int(time.time())
        window_start_time = current_time - self.time_window

        redis_client.zremrangebyscore(key, "-inf", window_start_time)
        attempts = redis_client.zcard(key)

        if attempts and int(attempts) >= self.max_attempts:
            return True
        return False

    def increment_rate_limit(self, email: str):
        key = self._get_key(email)
        current_time = int(time.time())

        redis_client.zadd(key, {current_time: current_time})
        redis_client.expire(key, self.time_window * 2)
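A usage sketch for the limiter (prefix and limits are illustrative): at most five attempts per email per hour, backed by a Redis sorted set.

```python
limiter = RateLimiter(prefix="login_limit", max_attempts=5, time_window=3600)

email_addr = "user@example.com"
if limiter.is_rate_limited(email_addr):
    raise ValueError("Too many attempts, try again later.")
limiter.increment_rate_limit(email_addr)
```

Note that `increment_rate_limit` uses the epoch second as the sorted-set member, so multiple attempts within the same second collapse into a single entry; the window is therefore approximate.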
5
dify/api/libs/infinite_scroll_pagination.py
Normal file
@@ -0,0 +1,5 @@
class InfiniteScrollPagination:
    def __init__(self, data, limit, has_more):
        self.data = data
        self.limit = limit
        self.has_more = has_more
52
dify/api/libs/json_in_md_parser.py
Normal file
@@ -0,0 +1,52 @@
import json

from core.llm_generator.output_parser.errors import OutputParserError


def parse_json_markdown(json_string: str):
    # Get json from the backticks/braces
    json_string = json_string.strip()
    starts = ["```json", "```", "``", "`", "{", "["]
    ends = ["```", "``", "`", "}", "]"]
    end_index = -1
    start_index = 0
    parsed: dict = {}
    for s in starts:
        start_index = json_string.find(s)
        if start_index != -1:
            if json_string[start_index] not in ("{", "["):
                start_index += len(s)
            break
    if start_index != -1:
        for e in ends:
            end_index = json_string.rfind(e, start_index)
            if end_index != -1:
                if json_string[end_index] in ("}", "]"):
                    end_index += 1
                break
    if start_index != -1 and end_index != -1 and start_index < end_index:
        extracted_content = json_string[start_index:end_index].strip()
        parsed = json.loads(extracted_content)
    else:
        raise ValueError("could not find json block in the output.")

    return parsed


def parse_and_check_json_markdown(text: str, expected_keys: list[str]):
    try:
        json_obj = parse_json_markdown(text)
    except json.JSONDecodeError as e:
        raise OutputParserError(f"got invalid json object. error: {e}")

    if isinstance(json_obj, list):
        if len(json_obj) == 1 and isinstance(json_obj[0], dict):
            json_obj = json_obj[0]
        else:
            raise OutputParserError(f"got invalid return object. obj:{json_obj}")
    for key in expected_keys:
        if key not in json_obj:
            raise OutputParserError(
                f"got invalid return object. expected key `{key}` to be present, but got {json_obj}"
            )
    return json_obj
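An example with a fenced markdown block, the typical shape of LLM output this parser targets (the fence is built by concatenation only to keep this snippet readable):

```python
text = "`" * 3 + 'json\n{"action": "search", "query": "dify"}\n' + "`" * 3

obj = parse_and_check_json_markdown(text, expected_keys=["action", "query"])
assert obj["action"] == "search"
```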
98
dify/api/libs/login.py
Normal file
@@ -0,0 +1,98 @@
from collections.abc import Callable
from functools import wraps
from typing import Any, ParamSpec, TypeVar

from flask import current_app, g, has_request_context, request
from flask_login.config import EXEMPT_METHODS
from werkzeug.local import LocalProxy

from configs import dify_config
from libs.token import check_csrf_token
from models import Account
from models.model import EndUser

P = ParamSpec("P")
R = TypeVar("R")


def current_account_with_tenant():
    """
    Resolve the underlying account for the current user proxy and ensure tenant context exists.
    Allows tests to supply plain Account mocks without the LocalProxy helper.
    """
    user_proxy = current_user

    get_current_object = getattr(user_proxy, "_get_current_object", None)
    user = get_current_object() if callable(get_current_object) else user_proxy  # type: ignore

    if not isinstance(user, Account):
        raise ValueError("current_user must be an Account instance")
    assert user.current_tenant_id is not None, "The tenant information should be loaded."
    return user, user.current_tenant_id


def login_required(func: Callable[P, R]):
    """
    If you decorate a view with this, it will ensure that the current user is
    logged in and authenticated before calling the actual view. (If they are
    not, it calls the :attr:`LoginManager.unauthorized` callback.) For
    example::

        @app.route('/post')
        @login_required
        def post():
            pass

    If there are only certain times you need to require that your user is
    logged in, you can do so with::

        if not current_user.is_authenticated:
            return current_app.login_manager.unauthorized()

    ...which is essentially the code that this function adds to your views.

    It can be convenient to globally turn off authentication when unit testing.
    To enable this, if the application configuration variable `LOGIN_DISABLED`
    is set to `True`, this decorator will be ignored.

    .. Note ::

        Per `W3 guidelines for CORS preflight requests
        <http://www.w3.org/TR/cors/#cross-origin-request-with-preflight-0>`_,
        HTTP ``OPTIONS`` requests are exempt from login checks.

    :param func: The view function to decorate.
    :type func: function
    """

    @wraps(func)
    def decorated_view(*args: P.args, **kwargs: P.kwargs):
        if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
            pass
        elif current_user is not None and not current_user.is_authenticated:
            return current_app.login_manager.unauthorized()  # type: ignore
        # CSRF validation lives here to minimize conflicts
        # TODO: maybe find a better place for it.
        check_csrf_token(request, current_user.id)
        return current_app.ensure_sync(func)(*args, **kwargs)

    return decorated_view


def _get_user() -> EndUser | Account | None:
    if has_request_context():
        if "_login_user" not in g:
            current_app.login_manager._load_user()  # type: ignore

        return g._login_user

    return None


#: A proxy for the current user. If no user is logged in, this will be an
#: anonymous user.
# NOTE: typed as Any; use _get_current_object to inspect the concrete fields
current_user: Any = LocalProxy(lambda: _get_user())
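A hedged sketch of how these pieces combine in a view function (route wiring omitted; names are illustrative):

```python
@login_required
def workspace_overview():
    # current_user is authenticated at this point; resolve the concrete
    # Account and its tenant id from the LocalProxy.
    account, tenant_id = current_account_with_tenant()
    return {"account_id": account.id, "tenant_id": tenant_id}
```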
54
dify/api/libs/module_loading.py
Normal file
@@ -0,0 +1,54 @@
"""
Module loading utilities similar to Django's module_loading.

Reference implementation from Django:
https://github.com/django/django/blob/main/django/utils/module_loading.py
"""

import sys
from importlib import import_module


def cached_import(module_path: str, class_name: str):
    """
    Import a module and return the named attribute/class from it, with caching.

    Args:
        module_path: The module path to import from
        class_name: The attribute/class name to retrieve

    Returns:
        The imported attribute/class
    """
    if not (
        (module := sys.modules.get(module_path))
        and (spec := getattr(module, "__spec__", None))
        and getattr(spec, "_initializing", False) is False
    ):
        module = import_module(module_path)
    return getattr(module, class_name)


def import_string(dotted_path: str):
    """
    Import a dotted module path and return the attribute/class designated by
    the last name in the path. Raise ImportError if the import failed.

    Args:
        dotted_path: Full module path to the class (e.g., 'module.submodule.ClassName')

    Returns:
        The imported class or attribute

    Raises:
        ImportError: If the module or attribute cannot be imported
    """
    try:
        module_path, class_name = dotted_path.rsplit(".", 1)
    except ValueError as err:
        raise ImportError(f"{dotted_path} doesn't look like a module path") from err

    try:
        return cached_import(module_path, class_name)
    except AttributeError as err:
        raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute/class') from err
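An example of resolving an attribute from a dotted path, in the same spirit as Django's helper:

```python
json_dumps = import_string("json.dumps")
assert json_dumps({"a": 1}) == '{"a": 1}'

try:
    import_string("not_a_module_path")  # no dot, so the path cannot be split
except ImportError as e:
    print(e)
```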
132
dify/api/libs/oauth.py
Normal file
@@ -0,0 +1,132 @@
import urllib.parse
from dataclasses import dataclass

import httpx


@dataclass
class OAuthUserInfo:
    id: str
    name: str
    email: str


class OAuth:
    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
        self.client_id = client_id
        self.client_secret = client_secret
        self.redirect_uri = redirect_uri

    def get_authorization_url(self):
        raise NotImplementedError()

    def get_access_token(self, code: str):
        raise NotImplementedError()

    def get_raw_user_info(self, token: str):
        raise NotImplementedError()

    def get_user_info(self, token: str) -> OAuthUserInfo:
        raw_info = self.get_raw_user_info(token)
        return self._transform_user_info(raw_info)

    def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
        raise NotImplementedError()


class GitHubOAuth(OAuth):
    _AUTH_URL = "https://github.com/login/oauth/authorize"
    _TOKEN_URL = "https://github.com/login/oauth/access_token"
    _USER_INFO_URL = "https://api.github.com/user"
    _EMAIL_INFO_URL = "https://api.github.com/user/emails"

    def get_authorization_url(self, invite_token: str | None = None):
        params = {
            "client_id": self.client_id,
            "redirect_uri": self.redirect_uri,
            "scope": "user:email",  # Request only basic user information
        }
        if invite_token:
            params["state"] = invite_token
        return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"

    def get_access_token(self, code: str):
        data = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            "redirect_uri": self.redirect_uri,
        }
        headers = {"Accept": "application/json"}
        response = httpx.post(self._TOKEN_URL, data=data, headers=headers)

        response_json = response.json()
        access_token = response_json.get("access_token")

        if not access_token:
            raise ValueError(f"Error in GitHub OAuth: {response_json}")

        return access_token

    def get_raw_user_info(self, token: str):
        headers = {"Authorization": f"token {token}"}
        response = httpx.get(self._USER_INFO_URL, headers=headers)
        response.raise_for_status()
        user_info = response.json()

        email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
        email_info = email_response.json()
        primary_email: dict = next((email for email in email_info if email.get("primary")), {})

        return {**user_info, "email": primary_email.get("email", "")}

    def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
        email = raw_info.get("email")
        if not email:
            email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com"
        return OAuthUserInfo(id=str(raw_info["id"]), name=raw_info["name"], email=email)


class GoogleOAuth(OAuth):
    _AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
    _TOKEN_URL = "https://oauth2.googleapis.com/token"
    _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"

    def get_authorization_url(self, invite_token: str | None = None):
        params = {
            "client_id": self.client_id,
            "response_type": "code",
            "redirect_uri": self.redirect_uri,
            "scope": "openid email",
        }
        if invite_token:
            params["state"] = invite_token
        return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"

    def get_access_token(self, code: str):
        data = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": self.redirect_uri,
        }
        headers = {"Accept": "application/json"}
        response = httpx.post(self._TOKEN_URL, data=data, headers=headers)

        response_json = response.json()
        access_token = response_json.get("access_token")

        if not access_token:
            raise ValueError(f"Error in Google OAuth: {response_json}")

        return access_token

    def get_raw_user_info(self, token: str):
        headers = {"Authorization": f"Bearer {token}"}
        response = httpx.get(self._USER_INFO_URL, headers=headers)
        response.raise_for_status()
        return response.json()

    def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
        return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"])
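A hedged sketch of the three-legged flow with `GitHubOAuth`; the credentials and the callback handling are illustrative, not part of this module:

```python
oauth = GitHubOAuth(
    client_id="<client-id>",
    client_secret="<client-secret>",
    redirect_uri="https://example.com/console/api/oauth/authorize/github",
)

# 1. Send the browser to the provider.
auth_url = oauth.get_authorization_url()


# 2. In the callback view, exchange the code and fetch the normalized profile.
def handle_callback(code: str) -> OAuthUserInfo:
    token = oauth.get_access_token(code)
    return oauth.get_user_info(token)
```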
305
dify/api/libs/oauth_data_source.py
Normal file
@@ -0,0 +1,305 @@
import urllib.parse
from typing import Any

import httpx
from flask_login import current_user
from sqlalchemy import select

from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding


class OAuthDataSource:
    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
        self.client_id = client_id
        self.client_secret = client_secret
        self.redirect_uri = redirect_uri

    def get_authorization_url(self):
        raise NotImplementedError()

    def get_access_token(self, code: str):
        raise NotImplementedError()


class NotionOAuth(OAuthDataSource):
    _AUTH_URL = "https://api.notion.com/v1/oauth/authorize"
    _TOKEN_URL = "https://api.notion.com/v1/oauth/token"
    _NOTION_PAGE_SEARCH = "https://api.notion.com/v1/search"
    _NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks"
    _NOTION_BOT_USER = "https://api.notion.com/v1/users/me"

    def get_authorization_url(self):
        params = {
            "client_id": self.client_id,
            "response_type": "code",
            "redirect_uri": self.redirect_uri,
            "owner": "user",
        }
        return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"

    def get_access_token(self, code: str):
        data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
        headers = {"Accept": "application/json"}
        auth = (self.client_id, self.client_secret)
        response = httpx.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)

        response_json = response.json()
        access_token = response_json.get("access_token")
        if not access_token:
            raise ValueError(f"Error in Notion OAuth: {response_json}")
        workspace_name = response_json.get("workspace_name")
        workspace_icon = response_json.get("workspace_icon")
        workspace_id = response_json.get("workspace_id")
        # get all authorized pages
        pages = self.get_authorized_pages(access_token)
        source_info = {
            "workspace_name": workspace_name,
            "workspace_icon": workspace_icon,
            "workspace_id": workspace_id,
            "pages": pages,
            "total": len(pages),
        }
        # save data source binding
        data_source_binding = db.session.scalar(
            select(DataSourceOauthBinding).where(
                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
                DataSourceOauthBinding.provider == "notion",
                DataSourceOauthBinding.access_token == access_token,
            )
        )
        if data_source_binding:
            data_source_binding.source_info = source_info
            data_source_binding.disabled = False
            data_source_binding.updated_at = naive_utc_now()
            db.session.commit()
        else:
            new_data_source_binding = DataSourceOauthBinding(
                tenant_id=current_user.current_tenant_id,
                access_token=access_token,
                source_info=source_info,
                provider="notion",
            )
            db.session.add(new_data_source_binding)
            db.session.commit()

    def save_internal_access_token(self, access_token: str):
        workspace_name = self.notion_workspace_name(access_token)
        workspace_icon = None
        workspace_id = current_user.current_tenant_id
        # get all authorized pages
        pages = self.get_authorized_pages(access_token)
        source_info = {
            "workspace_name": workspace_name,
            "workspace_icon": workspace_icon,
            "workspace_id": workspace_id,
            "pages": pages,
            "total": len(pages),
        }
        # save data source binding
        data_source_binding = db.session.scalar(
            select(DataSourceOauthBinding).where(
                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
                DataSourceOauthBinding.provider == "notion",
                DataSourceOauthBinding.access_token == access_token,
            )
        )
        if data_source_binding:
            data_source_binding.source_info = source_info
            data_source_binding.disabled = False
            data_source_binding.updated_at = naive_utc_now()
            db.session.commit()
        else:
            new_data_source_binding = DataSourceOauthBinding(
                tenant_id=current_user.current_tenant_id,
                access_token=access_token,
                source_info=source_info,
                provider="notion",
            )
            db.session.add(new_data_source_binding)
            db.session.commit()

    def sync_data_source(self, binding_id: str):
        # refresh an existing data source binding
        data_source_binding = db.session.scalar(
            select(DataSourceOauthBinding).where(
                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
                DataSourceOauthBinding.provider == "notion",
                DataSourceOauthBinding.id == binding_id,
                DataSourceOauthBinding.disabled == False,
            )
        )

        if data_source_binding:
            # get all authorized pages
            pages = self.get_authorized_pages(data_source_binding.access_token)
            source_info = data_source_binding.source_info
            new_source_info = {
                "workspace_name": source_info["workspace_name"],
                "workspace_icon": source_info["workspace_icon"],
                "workspace_id": source_info["workspace_id"],
                "pages": pages,
                "total": len(pages),
            }
            data_source_binding.source_info = new_source_info
            data_source_binding.disabled = False
            data_source_binding.updated_at = naive_utc_now()
            db.session.commit()
        else:
            raise ValueError("Data source binding not found")

    def get_authorized_pages(self, access_token: str):
        pages = []
        page_results = self.notion_page_search(access_token)
        database_results = self.notion_database_search(access_token)
        # get page detail
        for page_result in page_results:
            page_id = page_result["id"]
            page_name = "Untitled"
            for key in page_result["properties"]:
                if "title" in page_result["properties"][key] and page_result["properties"][key]["title"]:
                    title_list = page_result["properties"][key]["title"]
                    if len(title_list) > 0 and "plain_text" in title_list[0]:
                        page_name = title_list[0]["plain_text"]
            page_icon = page_result["icon"]
            if page_icon:
                icon_type = page_icon["type"]
                if icon_type in {"external", "file"}:
                    url = page_icon[icon_type]["url"]
                    icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"}
                else:
                    icon = {"type": "emoji", "emoji": page_icon[icon_type]}
            else:
                icon = None
            parent = page_result["parent"]
            parent_type = parent["type"]
            if parent_type == "block_id":
                parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type])
            elif parent_type == "workspace":
                parent_id = "root"
            else:
                parent_id = parent[parent_type]
            page = {
                "page_id": page_id,
                "page_name": page_name,
                "page_icon": icon,
                "parent_id": parent_id,
                "type": "page",
            }
            pages.append(page)
        # get database detail
        for database_result in database_results:
            page_id = database_result["id"]
            if len(database_result["title"]) > 0:
                page_name = database_result["title"][0]["plain_text"]
            else:
                page_name = "Untitled"
            page_icon = database_result["icon"]
            if page_icon:
                icon_type = page_icon["type"]
                if icon_type in {"external", "file"}:
                    url = page_icon[icon_type]["url"]
                    icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"}
                else:
                    icon = {"type": icon_type, icon_type: page_icon[icon_type]}
            else:
                icon = None
            parent = database_result["parent"]
            parent_type = parent["type"]
            if parent_type == "block_id":
                parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type])
            elif parent_type == "workspace":
                parent_id = "root"
            else:
                parent_id = parent[parent_type]
            page = {
                "page_id": page_id,
                "page_name": page_name,
                "page_icon": icon,
                "parent_id": parent_id,
                "type": "database",
            }
            pages.append(page)
        return pages

    def notion_page_search(self, access_token: str):
        results = []
        next_cursor = None
        has_more = True

        while has_more:
            data: dict[str, Any] = {
                "filter": {"value": "page", "property": "object"},
                **({"start_cursor": next_cursor} if next_cursor else {}),
            }
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {access_token}",
                "Notion-Version": "2022-06-28",
            }

            response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
            response_json = response.json()

            results.extend(response_json.get("results", []))

            has_more = response_json.get("has_more", False)
            next_cursor = response_json.get("next_cursor", None)

        return results

    def notion_block_parent_page_id(self, access_token: str, block_id: str):
        headers = {
            "Authorization": f"Bearer {access_token}",
            "Notion-Version": "2022-06-28",
        }
        response = httpx.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
        response_json = response.json()
        if response.status_code != 200:
            message = response_json.get("message", "unknown error")
            raise ValueError(f"Error fetching block parent page ID: {message}")
        parent = response_json["parent"]
        parent_type = parent["type"]
        if parent_type == "block_id":
            return self.notion_block_parent_page_id(access_token, parent[parent_type])
        return parent[parent_type]

    def notion_workspace_name(self, access_token: str):
        headers = {
            "Authorization": f"Bearer {access_token}",
            "Notion-Version": "2022-06-28",
        }
        response = httpx.get(url=self._NOTION_BOT_USER, headers=headers)
        response_json = response.json()
        if "object" in response_json and response_json["object"] == "user":
            user_type = response_json["type"]
            user_info = response_json[user_type]
            if "workspace_name" in user_info:
                return user_info["workspace_name"]
        return "workspace"

    def notion_database_search(self, access_token: str):
        results = []
        next_cursor = None
        has_more = True

        while has_more:
            data: dict[str, Any] = {
                "filter": {"value": "database", "property": "object"},
                **({"start_cursor": next_cursor} if next_cursor else {}),
            }
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {access_token}",
                "Notion-Version": "2022-06-28",
            }
            # databases are queried through the same /v1/search endpoint, with a different filter
            response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
            response_json = response.json()

            results.extend(response_json.get("results", []))

            has_more = response_json.get("has_more", False)
            next_cursor = response_json.get("next_cursor", None)

        return results
11
dify/api/libs/orjson.py
Normal file
@@ -0,0 +1,11 @@
from typing import Any

import orjson


def orjson_dumps(
    obj: Any,
    encoding: str = "utf-8",
    option: int | None = None,
) -> str:
    return orjson.dumps(obj, option=option).decode(encoding)
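A small example; orjson option flags are plain ints and can be combined with `|`:

```python
import orjson

print(orjson_dumps({"b": 1, "a": 2}, option=orjson.OPT_SORT_KEYS))  # {"a":2,"b":1}
```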
24
dify/api/libs/passport.py
Normal file
@@ -0,0 +1,24 @@
import jwt
from werkzeug.exceptions import Unauthorized

from configs import dify_config


class PassportService:
    def __init__(self):
        self.sk = dify_config.SECRET_KEY

    def issue(self, payload):
        return jwt.encode(payload, self.sk, algorithm="HS256")

    def verify(self, token):
        try:
            return jwt.decode(token, self.sk, algorithms=["HS256"])
        except jwt.ExpiredSignatureError:
            raise Unauthorized("Token has expired.")
        except jwt.InvalidSignatureError:
            raise Unauthorized("Invalid token signature.")
        except jwt.DecodeError:
            raise Unauthorized("Invalid token.")
        except jwt.PyJWTError:  # Catch-all for other JWT errors
            raise Unauthorized("Invalid token.")
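An issue/verify round trip (payload shape is illustrative; the `exp` claim is honored by PyJWT):

```python
service = PassportService()
token = service.issue({"user_id": "abc", "exp": 1893456000})

claims = service.verify(token)  # raises Unauthorized on tampering or expiry
assert claims["user_id"] == "abc"
```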
26
dify/api/libs/password.py
Normal file
@@ -0,0 +1,26 @@
import base64
import binascii
import hashlib
import re

password_pattern = r"^(?=.*[a-zA-Z])(?=.*\d).{8,}$"


def valid_password(password):
    # Define a regex pattern for password rules
    pattern = password_pattern
    # Check if the password matches the pattern
    if re.match(pattern, password) is not None:
        return password

    raise ValueError("Password must contain letters and numbers, and be at least 8 characters long.")


def hash_password(password_str, salt_byte):
    dk = hashlib.pbkdf2_hmac("sha256", password_str.encode("utf-8"), salt_byte, 10000)
    return binascii.hexlify(dk)


def compare_password(password_str, password_hashed_base64, salt_base64):
    # compare password for login
    return hash_password(password_str, base64.b64decode(salt_base64)) == base64.b64decode(password_hashed_base64)
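A sketch of how these helpers pair up, assuming the salt and hash are stored base64-encoded as `compare_password` expects:

```python
import base64
import os

salt = os.urandom(16)
hashed = hash_password("passw0rd1", salt)  # hex-encoded PBKDF2 digest, as bytes

assert compare_password(
    "passw0rd1",
    base64.b64encode(hashed).decode(),
    base64.b64encode(salt).decode(),
)
```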
94
dify/api/libs/rsa.py
Normal file
@@ -0,0 +1,94 @@
import hashlib
from typing import Union

from Crypto.Cipher import AES
from Crypto.PublicKey import RSA
from Crypto.Random import get_random_bytes

from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import gmpy2_pkcs10aep_cipher


def generate_key_pair(tenant_id: str) -> str:
    private_key = RSA.generate(2048)
    public_key = private_key.publickey()

    pem_private = private_key.export_key()
    pem_public = public_key.export_key()

    filepath = f"privkeys/{tenant_id}/private.pem"

    storage.save(filepath, pem_private)

    return pem_public.decode()


prefix_hybrid = b"HYBRID:"


def encrypt(text: str, public_key: Union[str, bytes]) -> bytes:
    if isinstance(public_key, str):
        public_key = public_key.encode()

    aes_key = get_random_bytes(16)
    cipher_aes = AES.new(aes_key, AES.MODE_EAX)

    ciphertext, tag = cipher_aes.encrypt_and_digest(text.encode())

    rsa_key = RSA.import_key(public_key)
    cipher_rsa = gmpy2_pkcs10aep_cipher.new(rsa_key)

    enc_aes_key: bytes = cipher_rsa.encrypt(aes_key)

    encrypted_data = enc_aes_key + cipher_aes.nonce + tag + ciphertext

    return prefix_hybrid + encrypted_data


def get_decrypt_decoding(tenant_id: str) -> tuple[RSA.RsaKey, object]:
    filepath = f"privkeys/{tenant_id}/private.pem"

    cache_key = f"tenant_privkey:{hashlib.sha3_256(filepath.encode()).hexdigest()}"
    private_key = redis_client.get(cache_key)
    if not private_key:
        try:
            private_key = storage.load(filepath)
        except FileNotFoundError:
            raise PrivkeyNotFoundError(f"Private key not found, tenant_id: {tenant_id}")

        redis_client.setex(cache_key, 120, private_key)

    rsa_key = RSA.import_key(private_key)
    cipher_rsa = gmpy2_pkcs10aep_cipher.new(rsa_key)

    return rsa_key, cipher_rsa


def decrypt_token_with_decoding(encrypted_text: bytes, rsa_key: RSA.RsaKey, cipher_rsa) -> str:
    if encrypted_text.startswith(prefix_hybrid):
        encrypted_text = encrypted_text[len(prefix_hybrid) :]

        enc_aes_key = encrypted_text[: rsa_key.size_in_bytes()]
        nonce = encrypted_text[rsa_key.size_in_bytes() : rsa_key.size_in_bytes() + 16]
        tag = encrypted_text[rsa_key.size_in_bytes() + 16 : rsa_key.size_in_bytes() + 32]
        ciphertext = encrypted_text[rsa_key.size_in_bytes() + 32 :]

        aes_key = cipher_rsa.decrypt(enc_aes_key)

        cipher_aes = AES.new(aes_key, AES.MODE_EAX, nonce=nonce)
        decrypted_text = cipher_aes.decrypt_and_verify(ciphertext, tag)
    else:
        decrypted_text = cipher_rsa.decrypt(encrypted_text)

    return decrypted_text.decode()


def decrypt(encrypted_text: bytes, tenant_id: str) -> str:
    rsa_key, cipher_rsa = get_decrypt_decoding(tenant_id)

    return decrypt_token_with_decoding(encrypted_text=encrypted_text, rsa_key=rsa_key, cipher_rsa=cipher_rsa)


class PrivkeyNotFoundError(Exception):
    pass
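A hedged round-trip sketch; it assumes working `storage` and Redis backends, and the tenant id is illustrative:

```python
tenant_id = "tenant-123"
public_pem = generate_key_pair(tenant_id)  # persists the private key via storage

ciphertext = encrypt("my-provider-api-key", public_pem)  # HYBRID: AES-EAX + RSA-OAEP
assert decrypt(ciphertext, tenant_id) == "my-provider-api-key"
```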
108
dify/api/libs/schedule_utils.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytz
|
||||
from croniter import croniter
|
||||
|
||||
|
||||
def calculate_next_run_at(
|
||||
cron_expression: str,
|
||||
timezone: str,
|
||||
base_time: datetime | None = None,
|
||||
) -> datetime:
|
||||
"""
|
||||
Calculate the next run time for a cron expression in a specific timezone.
|
||||
|
||||
Args:
|
||||
cron_expression: Standard 5-field cron expression or predefined expression
|
||||
timezone: Timezone string (e.g., 'UTC', 'America/New_York')
|
||||
base_time: Base time to calculate from (defaults to current UTC time)
|
||||
|
||||
Returns:
|
||||
Next run time in UTC
|
||||
|
||||
Note:
|
||||
Supports enhanced cron syntax including:
|
||||
- Month abbreviations: JAN, FEB, MAR-JUN, JAN,JUN,DEC
|
||||
- Day abbreviations: MON, TUE, MON-FRI, SUN,WED,FRI
|
||||
- Predefined expressions: @daily, @weekly, @monthly, @yearly, @hourly
|
||||
- Special characters: ? wildcard, L (last day), Sunday as 7
|
||||
- Standard 5-field format only (minute hour day month dayOfWeek)
|
||||
"""
|
||||
# Validate cron expression format to match frontend behavior
|
||||
parts = cron_expression.strip().split()
|
||||
|
||||
# Support both 5-field format and predefined expressions (matching frontend)
|
||||
if len(parts) != 5 and not cron_expression.startswith("@"):
|
||||
raise ValueError(
|
||||
f"Cron expression must have exactly 5 fields or be a predefined expression "
|
||||
f"(@daily, @weekly, etc.). Got {len(parts)} fields: '{cron_expression}'"
|
||||
)
|
||||
|
||||
tz = pytz.timezone(timezone)
|
||||
|
||||
if base_time is None:
|
||||
base_time = datetime.now(UTC)
|
||||
|
||||
base_time_tz = base_time.astimezone(tz)
|
||||
cron = croniter(cron_expression, base_time_tz)
|
||||
next_run_tz = cron.get_next(datetime)
|
||||
next_run_utc = next_run_tz.astimezone(UTC)
|
||||
|
||||
return next_run_utc
|
||||
|
||||
|
||||
def convert_12h_to_24h(time_str: str) -> tuple[int, int]:
    """
    Parse 12-hour time format to 24-hour format for cron compatibility.

    Args:
        time_str: Time string in format "HH:MM AM/PM" (e.g., "12:30 PM")

    Returns:
        Tuple of (hour, minute) in 24-hour format

    Raises:
        ValueError: If the time string format is invalid or values are out of range

    Examples:
        - "12:00 AM" -> (0, 0)   # Midnight
        - "12:00 PM" -> (12, 0)  # Noon
        - "1:30 PM" -> (13, 30)
        - "11:59 PM" -> (23, 59)
    """
    if not time_str or not time_str.strip():
        raise ValueError("Time string cannot be empty")

    parts = time_str.strip().split()
    if len(parts) != 2:
        raise ValueError(f"Invalid time format: '{time_str}'. Expected 'HH:MM AM/PM'")

    time_part, period = parts
    period = period.upper()

    if period not in ["AM", "PM"]:
        raise ValueError(f"Invalid period: '{period}'. Must be 'AM' or 'PM'")

    time_parts = time_part.split(":")
    if len(time_parts) != 2:
        raise ValueError(f"Invalid time format: '{time_part}'. Expected 'HH:MM'")

    try:
        hour = int(time_parts[0])
        minute = int(time_parts[1])
    except ValueError as e:
        raise ValueError(f"Invalid time values: {e}") from e

    if hour < 1 or hour > 12:
        raise ValueError(f"Invalid hour: {hour}. Must be between 1 and 12")

    if minute < 0 or minute > 59:
        raise ValueError(f"Invalid minute: {minute}. Must be between 0 and 59")

    # Handle 12-hour to 24-hour edge cases: 12 PM stays 12, 12 AM becomes 0
    if period == "PM" and hour != 12:
        hour += 12
    elif period == "AM" and hour == 12:
        hour = 0

    return hour, minute
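# Illustrative sketch (not part of the module): turn a user-facing "2:30 PM"
# into a daily cron expression and compute its next UTC occurrence.
hour, minute = convert_12h_to_24h("2:30 PM")  # -> (14, 30)
next_run = calculate_next_run_at(f"{minute} {hour} * * *", "UTC")  # "30 14 * * *"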
47
dify/api/libs/sendgrid.py
Normal file
@@ -0,0 +1,47 @@
import logging

import sendgrid
from python_http_client.exceptions import ForbiddenError, UnauthorizedError
from sendgrid.helpers.mail import Content, Email, Mail, To

logger = logging.getLogger(__name__)


class SendGridClient:
    def __init__(self, sendgrid_api_key: str, _from: str):
        self.sendgrid_api_key = sendgrid_api_key
        self._from = _from

    def send(self, mail: dict):
        logger.debug("Sending email with SendGrid")
        _to = ""
        try:
            _to = mail["to"]

            if not _to:
                raise ValueError("SendGridClient: Cannot send email: recipient address is missing.")

            sg = sendgrid.SendGridAPIClient(api_key=self.sendgrid_api_key)
            from_email = Email(self._from)
            to_email = To(_to)
            subject = mail["subject"]
            content = Content("text/html", mail["html"])
            sg_mail = Mail(from_email, to_email, subject, content)
            mail_json = sg_mail.get()
            response = sg.client.mail.send.post(request_body=mail_json)  # type: ignore
            logger.debug(response.status_code)
            logger.debug(response.body)
            logger.debug(response.headers)
        except TimeoutError:
            logger.exception("SendGridClient: Timeout occurred while sending email")
            raise
        except (UnauthorizedError, ForbiddenError):
            logger.exception(
                "SendGridClient: Authentication failed. "
                "Verify that your credentials and the 'from' email address are correct"
            )
            raise
        except Exception:
            logger.exception("SendGridClient: Unexpected error occurred while sending email to %s", _to)
            raise
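# Illustrative usage (a sketch, not part of this commit); the API key and
# addresses are placeholders.
client = SendGridClient(sendgrid_api_key="SG.xxxxx", _from="no-reply@example.com")
client.send({"to": "user@example.com", "subject": "Hello", "html": "<p>Hi!</p>"})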
59
dify/api/libs/smtp.py
Normal file
@@ -0,0 +1,59 @@
import logging
import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText

logger = logging.getLogger(__name__)


class SMTPClient:
    def __init__(
        self,
        server: str,
        port: int,
        username: str,
        password: str,
        _from: str,
        use_tls: bool = False,
        opportunistic_tls: bool = False,
    ):
        self.server = server
        self.port = port
        self._from = _from
        self.username = username
        self.password = password
        self.use_tls = use_tls
        self.opportunistic_tls = opportunistic_tls

    def send(self, mail: dict):
        smtp = None
        try:
            if self.use_tls:
                if self.opportunistic_tls:
                    smtp = smtplib.SMTP(self.server, self.port, timeout=10)
                    # Send EHLO command with the HELO domain name as the server address
                    smtp.ehlo(self.server)
                    smtp.starttls()
                    # Resend EHLO command to identify the TLS session
                    smtp.ehlo(self.server)
                else:
                    smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10)
            else:
                smtp = smtplib.SMTP(self.server, self.port, timeout=10)

            # Only authenticate if both username and password are non-empty
            if self.username and self.password and self.username.strip() and self.password.strip():
                smtp.login(self.username, self.password)

            msg = MIMEMultipart()
            msg["Subject"] = mail["subject"]
            msg["From"] = self._from
            msg["To"] = mail["to"]
            msg.attach(MIMEText(mail["html"], "html"))

            smtp.sendmail(self._from, mail["to"], msg.as_string())
        except smtplib.SMTPException:
            logger.exception("SMTP error occurred")
            raise
        except TimeoutError:
            logger.exception("Timeout occurred while sending email")
            raise
        except Exception:
            logger.exception("Unexpected error occurred while sending email to %s", mail["to"])
            raise
        finally:
            if smtp:
                smtp.quit()
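# Illustrative usage (a sketch, not part of this commit); host and credentials
# are placeholders. With use_tls and opportunistic_tls both set, the client
# connects in plaintext on port 587 and upgrades via STARTTLS.
client = SMTPClient(
    server="smtp.example.com",
    port=587,
    username="mailer",
    password="secret",
    _from="no-reply@example.com",
    use_tls=True,
    opportunistic_tls=True,
)
client.send({"to": "user@example.com", "subject": "Hello", "html": "<p>Hi!</p>"})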
67
dify/api/libs/time_parser.py
Normal file
@@ -0,0 +1,67 @@
"""Time duration parser utility."""

import re
from datetime import UTC, datetime, timedelta


def parse_time_duration(duration_str: str) -> timedelta | None:
    """
    Parse time duration string to timedelta.

    Supported formats:
    - 7d: 7 days
    - 4h: 4 hours
    - 30m: 30 minutes
    - 30s: 30 seconds

    Args:
        duration_str: Duration string (e.g., "7d", "4h", "30m", "30s")

    Returns:
        timedelta object or None if invalid format
    """
    if not duration_str:
        return None

    # Pattern: number followed by unit (d, h, m, s)
    pattern = r"^(\d+)([dhms])$"
    match = re.match(pattern, duration_str.lower())

    if not match:
        return None

    value = int(match.group(1))
    unit = match.group(2)

    if unit == "d":
        return timedelta(days=value)
    elif unit == "h":
        return timedelta(hours=value)
    elif unit == "m":
        return timedelta(minutes=value)
    elif unit == "s":
        return timedelta(seconds=value)

    return None


def get_time_threshold(duration_str: str | None) -> datetime | None:
    """
    Get datetime threshold from duration string.

    Calculates the datetime that is duration_str ago from now.

    Args:
        duration_str: Duration string (e.g., "7d", "4h", "30m", "30s")

    Returns:
        datetime object representing the threshold time, or None if no duration
    """
    if not duration_str:
        return None

    duration = parse_time_duration(duration_str)
    if duration is None:
        return None

    return datetime.now(UTC) - duration
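# Quick usage sketch (illustrative, not part of the module):
assert parse_time_duration("7d") == timedelta(days=7)
assert parse_time_duration("90x") is None  # unknown unit -> None
cutoff = get_time_threshold("4h")  # datetime 4 hours ago, in UTC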
231
dify/api/libs/token.py
Normal file
@@ -0,0 +1,231 @@
import logging
import re
from datetime import UTC, datetime, timedelta

from flask import Request
from werkzeug.exceptions import Unauthorized
from werkzeug.wrappers import Response

from configs import dify_config
from constants import (
    COOKIE_NAME_ACCESS_TOKEN,
    COOKIE_NAME_CSRF_TOKEN,
    COOKIE_NAME_PASSPORT,
    COOKIE_NAME_REFRESH_TOKEN,
    COOKIE_NAME_WEBAPP_ACCESS_TOKEN,
    HEADER_NAME_CSRF_TOKEN,
    HEADER_NAME_PASSPORT,
)
from libs.passport import PassportService

logger = logging.getLogger(__name__)

CSRF_WHITE_LIST = [
    re.compile(r"/console/api/apps/[a-f0-9-]+/workflows/draft"),
]


# The server may sit behind a reverse proxy, so infer HTTPS from the configured
# URLs rather than from the request scheme.
def is_secure() -> bool:
    return dify_config.CONSOLE_WEB_URL.startswith("https") and dify_config.CONSOLE_API_URL.startswith("https")


def _cookie_domain() -> str | None:
    """
    Returns the normalized cookie domain.

    Leading dots are stripped from the configured domain. Historically, a leading dot
    indicated that a cookie should be sent to all subdomains, but modern browsers treat
    'example.com' and '.example.com' identically. This normalization ensures consistent
    behavior and avoids confusion.
    """
    domain = dify_config.COOKIE_DOMAIN.strip()
    domain = domain.removeprefix(".")
    return domain or None


def _real_cookie_name(cookie_name: str) -> str:
    # The __Host- prefix requires Secure, no Domain attribute, and Path=/, so it
    # is only applied when those conditions hold.
    if is_secure() and _cookie_domain() is None:
        return "__Host-" + cookie_name
    return cookie_name
def _try_extract_from_header(request: Request) -> str | None:
    auth_header = request.headers.get("Authorization")
    if not auth_header or " " not in auth_header:
        return None

    auth_scheme, auth_token = auth_header.split(None, 1)
    if auth_scheme.lower() != "bearer":
        return None

    return auth_token


def extract_refresh_token(request: Request) -> str | None:
    return request.cookies.get(_real_cookie_name(COOKIE_NAME_REFRESH_TOKEN))


def extract_csrf_token(request: Request) -> str | None:
    return request.headers.get(HEADER_NAME_CSRF_TOKEN)


def extract_csrf_token_from_cookie(request: Request) -> str | None:
    return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN))


def extract_access_token(request: Request) -> str | None:
    def _try_extract_from_cookie(request: Request) -> str | None:
        return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))

    return _try_extract_from_cookie(request) or _try_extract_from_header(request)
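# Illustrative sketch with Flask's test helpers (assumes the dify config is
# loadable): with no access-token cookie present, the Bearer header is used.
from flask import Flask, request

app = Flask(__name__)
with app.test_request_context(headers={"Authorization": "Bearer abc123"}):
    assert extract_access_token(request) == "abc123"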
def extract_webapp_access_token(request: Request) -> str | None:
    return request.cookies.get(_real_cookie_name(COOKIE_NAME_WEBAPP_ACCESS_TOKEN)) or _try_extract_from_header(request)


def extract_webapp_passport(app_code: str, request: Request) -> str | None:
    def _try_extract_passport_token_from_cookie(request: Request) -> str | None:
        return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code))

    def _try_extract_passport_token_from_header(request: Request) -> str | None:
        return request.headers.get(HEADER_NAME_PASSPORT)

    return _try_extract_passport_token_from_cookie(request) or _try_extract_passport_token_from_header(request)
def set_access_token_to_cookie(request: Request, response: Response, token: str, samesite: str = "Lax"):
    response.set_cookie(
        _real_cookie_name(COOKIE_NAME_ACCESS_TOKEN),
        value=token,
        httponly=True,
        domain=_cookie_domain(),
        secure=is_secure(),
        samesite=samesite,
        max_age=int(dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 60),
        path="/",
    )


def set_refresh_token_to_cookie(request: Request, response: Response, token: str):
    response.set_cookie(
        _real_cookie_name(COOKIE_NAME_REFRESH_TOKEN),
        value=token,
        httponly=True,
        domain=_cookie_domain(),
        secure=is_secure(),
        samesite="Lax",
        max_age=int(60 * 60 * 24 * dify_config.REFRESH_TOKEN_EXPIRE_DAYS),
        path="/",
    )


def set_csrf_token_to_cookie(request: Request, response: Response, token: str):
    response.set_cookie(
        _real_cookie_name(COOKIE_NAME_CSRF_TOKEN),
        value=token,
        httponly=False,
        domain=_cookie_domain(),
        secure=is_secure(),
        samesite="Lax",
        max_age=int(60 * dify_config.ACCESS_TOKEN_EXPIRE_MINUTES),
        path="/",
    )


def _clear_cookie(
    response: Response,
    cookie_name: str,
    samesite: str = "Lax",
    http_only: bool = True,
):
    response.set_cookie(
        _real_cookie_name(cookie_name),
        "",
        expires=0,
        path="/",
        domain=_cookie_domain(),
        secure=is_secure(),
        httponly=http_only,
        samesite=samesite,
    )


def clear_access_token_from_cookie(response: Response, samesite: str = "Lax"):
    _clear_cookie(response, COOKIE_NAME_ACCESS_TOKEN, samesite)


def clear_webapp_access_token_from_cookie(response: Response, samesite: str = "Lax"):
    _clear_cookie(response, COOKIE_NAME_WEBAPP_ACCESS_TOKEN, samesite)


def clear_refresh_token_from_cookie(response: Response):
    _clear_cookie(response, COOKIE_NAME_REFRESH_TOKEN)


def clear_csrf_token_from_cookie(response: Response):
    _clear_cookie(response, COOKIE_NAME_CSRF_TOKEN, http_only=False)


def build_force_logout_cookie_headers() -> list[str]:
    """
    Generate Set-Cookie header values that clear all auth-related cookies.
    This mirrors the behavior of the standard cookie-clearing helpers while
    allowing callers that do not have a Response instance to reuse the logic.
    """
    response = Response()
    clear_access_token_from_cookie(response)
    clear_csrf_token_from_cookie(response)
    clear_refresh_token_from_cookie(response)
    return response.headers.getlist("Set-Cookie")
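# Illustrative sketch: callers that only emit raw headers (e.g., middleware
# without a Response object) can reuse the prebuilt values.
for set_cookie_value in build_force_logout_cookie_headers():
    print("Set-Cookie:", set_cookie_value)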
def check_csrf_token(request: Request, user_id: str):
    # Some APIs are sent via beacon, so we need to bypass the CSRF token check.
    # Since these APIs are POST, they are already protected by SameSite=Lax, so CSRF is not required.
    def _unauthorized():
        raise Unauthorized("CSRF token is missing or invalid.")

    for pattern in CSRF_WHITE_LIST:
        if pattern.match(request.path):
            return

    csrf_token = extract_csrf_token(request)
    csrf_token_from_cookie = extract_csrf_token_from_cookie(request)

    if csrf_token != csrf_token_from_cookie:
        _unauthorized()

    if not csrf_token:
        _unauthorized()
    verified = {}
    try:
        verified = PassportService().verify(csrf_token)
    except Exception:
        _unauthorized()

    if verified.get("sub") != user_id:
        _unauthorized()

    exp: int | None = verified.get("exp")
    if not exp:
        _unauthorized()
    else:
        time_now = int(datetime.now(UTC).timestamp())
        if exp < time_now:
            _unauthorized()


def generate_csrf_token(user_id: str) -> str:
    exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
    payload = {
        "exp": int(exp_dt.timestamp()),
        "sub": user_id,
    }
    return PassportService().issue(payload)
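# Illustrative round trip (a sketch; assumes the PassportService signing config
# is loaded): the token is set as a cookie at login and echoed back in the CSRF
# header on each state-changing request, where check_csrf_token verifies both
# copies match and the embedded "sub"/"exp" claims are valid.
csrf = generate_csrf_token(user_id="user-123")
# set_csrf_token_to_cookie(request, response, csrf)  # at login
# check_csrf_token(request, user_id="user-123")      # per request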
9
dify/api/libs/typing.py
Normal file
@@ -0,0 +1,9 @@
from typing import TypeGuard


def is_str_dict(v: object) -> TypeGuard[dict[str, object]]:
    # Note: only the container type is checked; string keys are assumed, not verified.
    return isinstance(v, dict)


def is_str(v: object) -> TypeGuard[str]:
    return isinstance(v, str)
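# Usage sketch (illustrative): TypeGuard lets a type checker narrow `data`
# inside the branch without a cast.
def read_name(data: object) -> str | None:
    if not is_str_dict(data):
        return None
    name = data.get("name")  # data is narrowed to dict[str, object]
    return name if is_str(name) else None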
164
dify/api/libs/uuid_utils.py
Normal file
@@ -0,0 +1,164 @@
import secrets
import struct
import time
import uuid

# Reference for UUIDv7 specification:
# RFC 9562, Section 5.7 - https://www.rfc-editor.org/rfc/rfc9562.html#section-5.7

# Define the format for packing the timestamp as an unsigned 64-bit integer (big-endian).
#
# For details on the `struct.pack` format, refer to:
# https://docs.python.org/3/library/struct.html#byte-order-size-and-alignment
_PACK_TIMESTAMP = ">Q"

# Define the format for packing the 12-bit random data A (as specified in RFC 9562 Section 5.7)
# into an unsigned 16-bit integer (big-endian).
_PACK_RAND_A = ">H"


def _create_uuidv7_bytes(timestamp_ms: int, random_bytes: bytes) -> bytes:
    """Create UUIDv7 byte structure with given timestamp and random bytes.

    This is a private helper function that handles the common logic for creating
    the UUIDv7 byte structure according to the RFC 9562 specification.

    UUIDv7 Structure:
    - 48 bits: timestamp (milliseconds since Unix epoch)
    - 12 bits: random data A (with version bits)
    - 62 bits: random data B (with variant bits)

    The function performs the following operations:
    1. Creates a 128-bit (16-byte) UUID structure
    2. Packs the timestamp into the first 48 bits (6 bytes)
    3. Sets the version bits to 7 (0111) in the correct position
    4. Sets the variant bits to 10 (binary) in the correct position
    5. Fills the remaining bits with the provided random bytes

    Args:
        timestamp_ms: The timestamp in milliseconds since Unix epoch (48 bits).
        random_bytes: Random bytes to use for the random portions (must be 10 bytes).
            First 2 bytes are used for random data A (12 bits after version).
            Last 8 bytes are used for random data B (62 bits after variant).

    Returns:
        A 16-byte bytes object representing the complete UUIDv7 structure.

    Note:
        This function assumes the random_bytes parameter is exactly 10 bytes.
        The caller is responsible for providing appropriate random data.
    """
    # Create the 128-bit UUID structure
    uuid_bytes = bytearray(16)

    # Pack timestamp (48 bits) into first 6 bytes
    uuid_bytes[0:6] = struct.pack(_PACK_TIMESTAMP, timestamp_ms)[2:8]  # Take last 6 bytes of 8-byte big-endian

    # Next 16 bits: random data A (12 bits) + version (4 bits)
    # Take first 2 random bytes and set version to 7
    rand_a = struct.unpack(_PACK_RAND_A, random_bytes[0:2])[0]
    # Clear the highest 4 bits to make room for the version field
    # by performing a bitwise AND with 0x0FFF (binary: 0b0000_1111_1111_1111).
    rand_a = rand_a & 0x0FFF
    # Set the version field to 7 (binary: 0111) by performing a bitwise OR with 0x7000 (binary: 0b0111_0000_0000_0000).
    rand_a = rand_a | 0x7000
    uuid_bytes[6:8] = struct.pack(_PACK_RAND_A, rand_a)

    # Last 64 bits: random data B (62 bits) + variant (2 bits)
    # Use remaining 8 random bytes and set variant to 10 (binary)
    uuid_bytes[8:16] = random_bytes[2:10]

    # Set variant bits (first 2 bits of byte 8 should be '10')
    uuid_bytes[8] = (uuid_bytes[8] & 0x3F) | 0x80  # Set variant to 10xxxxxx

    return bytes(uuid_bytes)


def uuidv7(timestamp_ms: int | None = None) -> uuid.UUID:
    """Generate a UUID version 7 according to the RFC 9562 specification.

    UUIDv7 features a time-ordered value field derived from the widely
    implemented and well-known Unix Epoch timestamp source: the number of
    milliseconds since midnight 1 Jan 1970 UTC, leap seconds excluded.

    Structure:
    - 48 bits: timestamp (milliseconds since Unix epoch)
    - 12 bits: random data A (with version bits)
    - 62 bits: random data B (with variant bits)

    Args:
        timestamp_ms: The timestamp used when generating the UUID; the current
            time is used if unspecified. Should be an integer representing
            milliseconds since the Unix epoch.

    Returns:
        A UUID object representing a UUIDv7.

    Example:
        >>> import time
        >>> # Generate UUIDv7 with current time
        >>> uuid_current = uuidv7()
        >>> # Generate UUIDv7 with specific timestamp
        >>> uuid_specific = uuidv7(int(time.time() * 1000))
    """
    if timestamp_ms is None:
        timestamp_ms = int(time.time() * 1000)

    # Generate 10 random bytes for the random portions
    random_bytes = secrets.token_bytes(10)

    # Create UUIDv7 bytes using the helper function
    uuid_bytes = _create_uuidv7_bytes(timestamp_ms, random_bytes)

    return uuid.UUID(bytes=uuid_bytes)
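# Quick ordering sketch (illustrative): UUIDv7 values sort by creation time,
# since the most significant bits hold the timestamp.
earlier = uuidv7(1_700_000_000_000)
later = uuidv7(1_700_000_000_001)
assert earlier < later
assert earlier.version == 7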
def uuidv7_timestamp(id_: uuid.UUID) -> int:
    """Extract the timestamp from a UUIDv7.

    UUIDv7 contains a 48-bit timestamp field representing milliseconds since
    the Unix epoch (1970-01-01 00:00:00 UTC). This function extracts and
    returns that timestamp as an integer representing milliseconds since the epoch.

    Args:
        id_: A UUID object that should be a UUIDv7 (version 7).

    Returns:
        The timestamp as an integer representing milliseconds since the Unix epoch.

    Raises:
        ValueError: If the provided UUID is not version 7.

    Example:
        >>> uuid_v7 = uuidv7()
        >>> timestamp = uuidv7_timestamp(uuid_v7)
        >>> print(f"UUID was created at: {timestamp} ms")
    """
    # Verify this is a UUIDv7
    if id_.version != 7:
        raise ValueError(f"Expected UUIDv7 (version 7), got version {id_.version}")

    # Extract the UUID bytes
    uuid_bytes = id_.bytes

    # Extract the first 48 bits (6 bytes) as the timestamp in milliseconds.
    # Pad with 2 zero bytes at the beginning to make it 8 bytes for unpacking as Q (unsigned long long).
    timestamp_bytes = b"\x00\x00" + uuid_bytes[0:6]
    ts_in_ms = struct.unpack(_PACK_TIMESTAMP, timestamp_bytes)[0]

    # Return timestamp directly in milliseconds as integer
    assert isinstance(ts_in_ms, int)
    return ts_in_ms


def uuidv7_boundary(timestamp_ms: int) -> uuid.UUID:
    """Generate a non-random UUIDv7 with the given timestamp (first 48 bits) and
    all random bits set to 0. As the smallest possible UUIDv7 for that timestamp,
    it may be used as a boundary for partitions.
    """
    # Use zero bytes for all random portions
    zero_random_bytes = b"\x00" * 10

    # Create UUIDv7 bytes using the helper function
    uuid_bytes = _create_uuidv7_bytes(timestamp_ms, zero_random_bytes)

    return uuid.UUID(bytes=uuid_bytes)
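# Illustrative sketch: recover the embedded timestamp, and build a lower
# boundary for time-partitioned range queries.
ts_ms = 1_700_000_000_000
assert uuidv7_timestamp(uuidv7(ts_ms)) == ts_ms
lower_bound = uuidv7_boundary(ts_ms)  # smallest possible UUIDv7 for this millisecond
assert uuidv7_timestamp(lower_bound) == ts_ms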
5
dify/api/libs/validators.py
Normal file
@@ -0,0 +1,5 @@
def validate_description_length(description: str | None) -> str | None:
    """Validate that a description does not exceed 400 characters."""
    if description and len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description
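# Usage sketch (illustrative): values pass through unchanged; over-long
# descriptions raise ValueError.
assert validate_description_length(None) is None
assert validate_description_length("ok") == "ok"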