This commit is contained in:
2025-12-01 17:21:38 +08:00
parent 32fee2b8ab
commit fab8c13cb3
7511 changed files with 996300 additions and 0 deletions

View File

View File

@@ -0,0 +1,134 @@
"""
Broadcast channel for Pub/Sub messaging.
"""
import types
from abc import abstractmethod
from collections.abc import Iterator
from contextlib import AbstractContextManager
from typing import Protocol, Self
class Subscription(AbstractContextManager["Subscription"], Protocol):
"""A subscription to a topic that provides an iterator over received messages.
The subscription can be used as a context manager and will automatically
close when exiting the context.
Note: `Subscription` instances are not thread-safe. Each thread should create its own
subscription.
"""
@abstractmethod
def __iter__(self) -> Iterator[bytes]:
"""`__iter__` returns an iterator used to consume the message from this subscription.
If the caller did not enter the context, `__iter__` may lazily perform the setup before
yielding messages; otherwise `__enter__` handles it.”
If the subscription is closed, then the returned iterator exits without
raising any error.
"""
...
@abstractmethod
def close(self) -> None:
"""close closes the subscription, releases any resources associated with it."""
...
def __enter__(self) -> Self:
"""`__enter__` does the setup logic of the subscription (if any), and return itself."""
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> bool | None:
self.close()
return None
@abstractmethod
def receive(self, timeout: float | None = 0.1) -> bytes | None:
"""Receive the next message from the broadcast channel.
If `timeout` is specified, this method returns `None` if no message is
received within the given period. If `timeout` is `None`, the call blocks
until a message is received.
Calling receive with `timeout=None` is highly discouraged, as it is impossible to
cancel a blocking subscription.
:param timeout: timeout for receive message, in seconds.
Returns:
bytes: The received message as a byte string, or
None: If the timeout expires before a message is received.
Raises:
SubscriptionClosed: If the subscription has already been closed.
"""
...
class Producer(Protocol):
"""Producer is an interface for message publishing. It is already bound to a specific topic.
`Producer` implementations must be thread-safe and support concurrent use by multiple threads.
"""
@abstractmethod
def publish(self, payload: bytes) -> None:
"""Publish a message to the bounded topic."""
...
class Subscriber(Protocol):
"""Subscriber is an interface for subscription creation. It is already bound to a specific topic.
`Subscriber` implementations must be thread-safe and support concurrent use by multiple threads.
"""
@abstractmethod
def subscribe(self) -> Subscription:
pass
class Topic(Producer, Subscriber, Protocol):
"""A named channel for publishing and subscribing to messages.
Topics provide both read and write access. For restricted access,
use as_producer() for write-only view or as_subscriber() for read-only view.
`Topic` implementations must be thread-safe and support concurrent use by multiple threads.
"""
@abstractmethod
def as_producer(self) -> Producer:
"""as_producer creates a write-only view for this topic."""
...
@abstractmethod
def as_subscriber(self) -> Subscriber:
"""as_subscriber create a read-only view for this topic."""
...
class BroadcastChannel(Protocol):
"""A broadcasting channel is a channel supporting broadcasting semantics.
Each channel is identified by a topic, different topics are isolated and do not affect each other.
There can be multiple subscriptions to a specific topic. When a publisher publishes a message to
a specific topic, all subscription should receive the published message.
There are no restriction for the persistence of messages. Once a subscription is created, it
should receive all subsequent messages published.
`BroadcastChannel` implementations must be thread-safe and support concurrent use by multiple threads.
"""
@abstractmethod
def topic(self, topic: str) -> "Topic":
"""topic returns a `Topic` instance for the given topic name."""
...

View File

@@ -0,0 +1,12 @@
class BroadcastChannelError(Exception):
"""`BroadcastChannelError` is the base class for all exceptions related
to `BroadcastChannel`."""
pass
class SubscriptionClosedError(BroadcastChannelError):
"""SubscriptionClosedError means that the subscription has been closed and
methods for consuming messages should not be called."""
pass

View File

@@ -0,0 +1,4 @@
from .channel import BroadcastChannel
from .sharded_channel import ShardedRedisBroadcastChannel
__all__ = ["BroadcastChannel", "ShardedRedisBroadcastChannel"]

View File

@@ -0,0 +1,227 @@
import logging
import queue
import threading
import types
from collections.abc import Generator, Iterator
from typing import Self
from libs.broadcast_channel.channel import Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from redis.client import PubSub
_logger = logging.getLogger(__name__)
class RedisSubscriptionBase(Subscription):
"""Base class for Redis pub/sub subscriptions with common functionality.
This class provides shared functionality for both regular and sharded
Redis pub/sub subscriptions, reducing code duplication and improving
maintainability.
"""
def __init__(
self,
pubsub: PubSub,
topic: str,
):
# The _pubsub is None only if the subscription is closed.
self._pubsub: PubSub | None = pubsub
self._topic = topic
self._closed = threading.Event()
self._queue: queue.Queue[bytes] = queue.Queue(maxsize=1024)
self._dropped_count = 0
self._listener_thread: threading.Thread | None = None
self._start_lock = threading.Lock()
self._started = False
def _start_if_needed(self) -> None:
"""Start the subscription if not already started."""
with self._start_lock:
if self._started:
return
if self._closed.is_set():
raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
if self._pubsub is None:
raise SubscriptionClosedError(
f"The Redis {self._get_subscription_type()} subscription has been cleaned up"
)
self._subscribe()
_logger.debug("Subscribed to %s channel %s", self._get_subscription_type(), self._topic)
self._listener_thread = threading.Thread(
target=self._listen,
name=f"redis-{self._get_subscription_type().replace(' ', '-')}-broadcast-{self._topic}",
daemon=True,
)
self._listener_thread.start()
self._started = True
def _listen(self) -> None:
"""Main listener loop for processing messages."""
pubsub = self._pubsub
assert pubsub is not None, "PubSub should not be None while starting listening."
while not self._closed.is_set():
try:
raw_message = self._get_message()
except Exception as e:
# Log the exception and exit the listener thread gracefully
# This handles Redis connection errors and other exceptions
_logger.error(
"Error getting message from Redis %s subscription, topic=%s: %s",
self._get_subscription_type(),
self._topic,
e,
exc_info=True,
)
break
if raw_message is None:
continue
if raw_message.get("type") != self._get_message_type():
continue
channel_field = raw_message.get("channel")
if isinstance(channel_field, bytes):
channel_name = channel_field.decode("utf-8")
elif isinstance(channel_field, str):
channel_name = channel_field
else:
channel_name = str(channel_field)
if channel_name != self._topic:
_logger.warning(
"Ignoring %s message from unexpected channel %s", self._get_subscription_type(), channel_name
)
continue
payload_bytes: bytes | None = raw_message.get("data")
if not isinstance(payload_bytes, bytes):
_logger.error(
"Received invalid data from %s channel %s, type=%s",
self._get_subscription_type(),
self._topic,
type(payload_bytes),
)
continue
self._enqueue_message(payload_bytes)
_logger.debug("%s listener thread stopped for channel %s", self._get_subscription_type().title(), self._topic)
try:
self._unsubscribe()
pubsub.close()
_logger.debug("%s PubSub closed for topic %s", self._get_subscription_type().title(), self._topic)
except Exception as e:
_logger.error(
"Error during cleanup of Redis %s subscription, topic=%s: %s",
self._get_subscription_type(),
self._topic,
e,
exc_info=True,
)
finally:
self._pubsub = None
def _enqueue_message(self, payload: bytes) -> None:
"""Enqueue a message to the internal queue with dropping behavior."""
while not self._closed.is_set():
try:
self._queue.put_nowait(payload)
return
except queue.Full:
try:
self._queue.get_nowait()
self._dropped_count += 1
_logger.debug(
"Dropped message from Redis %s subscription, topic=%s, total_dropped=%d",
self._get_subscription_type(),
self._topic,
self._dropped_count,
)
except queue.Empty:
continue
return
def _message_iterator(self) -> Generator[bytes, None, None]:
"""Iterator for consuming messages from the subscription."""
while not self._closed.is_set():
try:
item = self._queue.get(timeout=0.1)
except queue.Empty:
continue
yield item
def __iter__(self) -> Iterator[bytes]:
"""Return an iterator over messages from the subscription."""
if self._closed.is_set():
raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
self._start_if_needed()
return iter(self._message_iterator())
def receive(self, timeout: float | None = None) -> bytes | None:
"""Receive the next message from the subscription."""
if self._closed.is_set():
raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
self._start_if_needed()
try:
item = self._queue.get(timeout=timeout)
except queue.Empty:
return None
return item
def __enter__(self) -> Self:
"""Context manager entry point."""
self._start_if_needed()
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> bool | None:
"""Context manager exit point."""
self.close()
return None
def close(self) -> None:
"""Close the subscription and clean up resources."""
if self._closed.is_set():
return
self._closed.set()
# NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the
# message retrieval method should NOT be called concurrently.
#
# Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
listener = self._listener_thread
if listener is not None:
listener.join(timeout=1.0)
self._listener_thread = None
# Abstract methods to be implemented by subclasses
def _get_subscription_type(self) -> str:
"""Return the subscription type (e.g., 'regular' or 'sharded')."""
raise NotImplementedError
def _subscribe(self) -> None:
"""Subscribe to the Redis topic using the appropriate command."""
raise NotImplementedError
def _unsubscribe(self) -> None:
"""Unsubscribe from the Redis topic using the appropriate command."""
raise NotImplementedError
def _get_message(self) -> dict | None:
"""Get a message from Redis using the appropriate method."""
raise NotImplementedError
def _get_message_type(self) -> str:
"""Return the expected message type (e.g., 'message' or 'smessage')."""
raise NotImplementedError

View File

@@ -0,0 +1,67 @@
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis
from ._subscription import RedisSubscriptionBase
class BroadcastChannel:
"""
Redis Pub/Sub based broadcast channel implementation (regular, non-sharded).
Provides "at most once" delivery semantics for messages published to channels
using Redis PUBLISH/SUBSCRIBE commands for real-time message delivery.
The `redis_client` used to construct BroadcastChannel should have `decode_responses` set to `False`.
"""
def __init__(
self,
redis_client: Redis,
):
self._client = redis_client
def topic(self, topic: str) -> "Topic":
return Topic(self._client, topic)
class Topic:
def __init__(self, redis_client: Redis, topic: str):
self._client = redis_client
self._topic = topic
def as_producer(self) -> Producer:
return self
def publish(self, payload: bytes) -> None:
self._client.publish(self._topic, payload)
def as_subscriber(self) -> Subscriber:
return self
def subscribe(self) -> Subscription:
return _RedisSubscription(
pubsub=self._client.pubsub(),
topic=self._topic,
)
class _RedisSubscription(RedisSubscriptionBase):
"""Regular Redis pub/sub subscription implementation."""
def _get_subscription_type(self) -> str:
return "regular"
def _subscribe(self) -> None:
assert self._pubsub is not None
self._pubsub.subscribe(self._topic)
def _unsubscribe(self) -> None:
assert self._pubsub is not None
self._pubsub.unsubscribe(self._topic)
def _get_message(self) -> dict | None:
assert self._pubsub is not None
return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
def _get_message_type(self) -> str:
return "message"

View File

@@ -0,0 +1,65 @@
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis
from ._subscription import RedisSubscriptionBase
class ShardedRedisBroadcastChannel:
"""
Redis 7.0+ Sharded Pub/Sub based broadcast channel implementation.
Provides "at most once" delivery semantics using SPUBLISH/SSUBSCRIBE commands,
distributing channels across Redis cluster nodes for better scalability.
"""
def __init__(
self,
redis_client: Redis,
):
self._client = redis_client
def topic(self, topic: str) -> "ShardedTopic":
return ShardedTopic(self._client, topic)
class ShardedTopic:
def __init__(self, redis_client: Redis, topic: str):
self._client = redis_client
self._topic = topic
def as_producer(self) -> Producer:
return self
def publish(self, payload: bytes) -> None:
self._client.spublish(self._topic, payload) # type: ignore[attr-defined]
def as_subscriber(self) -> Subscriber:
return self
def subscribe(self) -> Subscription:
return _RedisShardedSubscription(
pubsub=self._client.pubsub(),
topic=self._topic,
)
class _RedisShardedSubscription(RedisSubscriptionBase):
"""Redis 7.0+ sharded pub/sub subscription implementation."""
def _get_subscription_type(self) -> str:
return "sharded"
def _subscribe(self) -> None:
assert self._pubsub is not None
self._pubsub.ssubscribe(self._topic) # type: ignore[attr-defined]
def _unsubscribe(self) -> None:
assert self._pubsub is not None
self._pubsub.sunsubscribe(self._topic) # type: ignore[attr-defined]
def _get_message(self) -> dict | None:
assert self._pubsub is not None
return self._pubsub.get_sharded_message(ignore_subscribe_messages=True, timeout=0.1) # type: ignore[attr-defined]
def _get_message_type(self) -> str:
return "smessage"

View File

@@ -0,0 +1,14 @@
def convert_to_lower_and_upper_set(inputs: list[str] | set[str]) -> set[str]:
"""
Convert a list or set of strings to a set containing both lower and upper case versions of each string.
Args:
inputs (list[str] | set[str]): A list or set of strings to be converted.
Returns:
set[str]: A set containing both lower and upper case versions of each string.
"""
if not inputs:
return set()
else:
return {case for s in inputs if s for case in (s.lower(), s.upper())}

View File

@@ -0,0 +1,32 @@
"""Custom input types for Flask-RESTX request parsing."""
import re
def time_duration(value: str) -> str:
"""
Validate and return time duration string.
Accepts formats: <number>d (days), <number>h (hours), <number>m (minutes), <number>s (seconds)
Examples: 7d, 4h, 30m, 30s
Args:
value: The time duration string
Returns:
The validated time duration string
Raises:
ValueError: If the format is invalid
"""
if not value:
raise ValueError("Time duration cannot be empty")
pattern = r"^(\d+)([dhms])$"
if not re.match(pattern, value.lower()):
raise ValueError(
"Invalid time duration format. Use: <number>d (days), <number>h (hours), "
"<number>m (minutes), or <number>s (seconds). Examples: 7d, 4h, 30m, 30s"
)
return value.lower()

View File

@@ -0,0 +1,83 @@
import abc
import datetime
from typing import Protocol
import pytz
class _NowFunction(Protocol):
@abc.abstractmethod
def __call__(self, tz: datetime.timezone | None) -> datetime.datetime:
pass
# _now_func is a callable with the _NowFunction signature.
# Its sole purpose is to abstract time retrieval, enabling
# developers to mock this behavior in tests and time-dependent scenarios.
_now_func: _NowFunction = datetime.datetime.now
def naive_utc_now() -> datetime.datetime:
"""Return a naive datetime object (without timezone information)
representing current UTC time.
"""
return _now_func(datetime.UTC).replace(tzinfo=None)
def ensure_naive_utc(dt: datetime.datetime) -> datetime.datetime:
"""Return the datetime as naive UTC (tzinfo=None).
If the input is timezone-aware, convert to UTC and drop the tzinfo.
Assumes naive datetimes are already expressed in UTC.
"""
if dt.tzinfo is None:
return dt
return dt.astimezone(datetime.UTC).replace(tzinfo=None)
def parse_time_range(
start: str | None, end: str | None, tzname: str
) -> tuple[datetime.datetime | None, datetime.datetime | None]:
"""
Parse time range strings and convert to UTC datetime objects.
Handles DST ambiguity and non-existent times gracefully.
Args:
start: Start time string (YYYY-MM-DD HH:MM)
end: End time string (YYYY-MM-DD HH:MM)
tzname: Timezone name
Returns:
tuple: (start_datetime_utc, end_datetime_utc)
Raises:
ValueError: When time range is invalid or start > end
"""
tz = pytz.timezone(tzname)
utc = pytz.utc
def _parse(time_str: str | None, label: str) -> datetime.datetime | None:
if not time_str:
return None
try:
dt = datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M").replace(second=0)
except ValueError as e:
raise ValueError(f"Invalid {label} time format: {e}")
try:
return tz.localize(dt, is_dst=None).astimezone(utc)
except pytz.AmbiguousTimeError:
return tz.localize(dt, is_dst=False).astimezone(utc)
except pytz.NonExistentTimeError:
dt += datetime.timedelta(hours=1)
return tz.localize(dt, is_dst=None).astimezone(utc)
start_dt = _parse(start, "start")
end_dt = _parse(end, "end")
# Range validation
if start_dt and end_dt and start_dt > end_dt:
raise ValueError("start must be earlier than or equal to end")
return start_dt, end_dt

604
dify/api/libs/email_i18n.py Normal file
View File

@@ -0,0 +1,604 @@
"""
Email Internationalization Module
This module provides a centralized, elegant way to handle email internationalization
in Dify. It follows Domain-Driven Design principles with proper type hints and
eliminates the need for repetitive language switching logic.
"""
from dataclasses import dataclass
from enum import StrEnum, auto
from typing import Any, Protocol
from flask import render_template
from pydantic import BaseModel, Field
from extensions.ext_mail import mail
from services.feature_service import BrandingModel, FeatureService
class EmailType(StrEnum):
"""Enumeration of supported email types."""
RESET_PASSWORD = auto()
RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST = auto()
INVITE_MEMBER = auto()
EMAIL_CODE_LOGIN = auto()
CHANGE_EMAIL_OLD = auto()
CHANGE_EMAIL_NEW = auto()
CHANGE_EMAIL_COMPLETED = auto()
OWNER_TRANSFER_CONFIRM = auto()
OWNER_TRANSFER_OLD_NOTIFY = auto()
OWNER_TRANSFER_NEW_NOTIFY = auto()
ACCOUNT_DELETION_SUCCESS = auto()
ACCOUNT_DELETION_VERIFICATION = auto()
ENTERPRISE_CUSTOM = auto()
QUEUE_MONITOR_ALERT = auto()
DOCUMENT_CLEAN_NOTIFY = auto()
EMAIL_REGISTER = auto()
EMAIL_REGISTER_WHEN_ACCOUNT_EXIST = auto()
RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER = auto()
TRIGGER_EVENTS_LIMIT_SANDBOX = auto()
TRIGGER_EVENTS_LIMIT_PROFESSIONAL = auto()
TRIGGER_EVENTS_USAGE_WARNING_SANDBOX = auto()
TRIGGER_EVENTS_USAGE_WARNING_PROFESSIONAL = auto()
API_RATE_LIMIT_LIMIT_SANDBOX = auto()
API_RATE_LIMIT_WARNING_SANDBOX = auto()
class EmailLanguage(StrEnum):
"""Supported email languages with fallback handling."""
EN_US = "en-US"
ZH_HANS = "zh-Hans"
@classmethod
def from_language_code(cls, language_code: str) -> "EmailLanguage":
"""Convert a language code to EmailLanguage with fallback to English."""
if language_code == "zh-Hans":
return cls.ZH_HANS
return cls.EN_US
@dataclass(frozen=True)
class EmailTemplate:
"""Immutable value object representing an email template configuration."""
subject: str
template_path: str
branded_template_path: str
@dataclass(frozen=True)
class EmailContent:
"""Immutable value object containing rendered email content."""
subject: str
html_content: str
template_context: dict[str, Any]
class EmailI18nConfig(BaseModel):
"""Configuration for email internationalization."""
model_config = {"frozen": True, "extra": "forbid"}
templates: dict[EmailType, dict[EmailLanguage, EmailTemplate]] = Field(
default_factory=dict, description="Mapping of email types to language-specific templates"
)
def get_template(self, email_type: EmailType, language: EmailLanguage) -> EmailTemplate:
"""Get template configuration for specific email type and language."""
type_templates = self.templates.get(email_type)
if not type_templates:
raise ValueError(f"No templates configured for email type: {email_type}")
template = type_templates.get(language)
if not template:
# Fallback to English if specific language not found
template = type_templates.get(EmailLanguage.EN_US)
if not template:
raise ValueError(f"No template found for {email_type} in {language} or English")
return template
class EmailRenderer(Protocol):
"""Protocol for email template renderers."""
def render_template(self, template_path: str, **context: Any) -> str:
"""Render email template with given context."""
...
class FlaskEmailRenderer:
"""Flask-based email template renderer."""
def render_template(self, template_path: str, **context: Any) -> str:
"""Render email template using Flask's render_template."""
return render_template(template_path, **context)
class BrandingService(Protocol):
"""Protocol for branding service abstraction."""
def get_branding_config(self) -> BrandingModel:
"""Get current branding configuration."""
...
class FeatureBrandingService:
"""Feature service based branding implementation."""
def get_branding_config(self) -> BrandingModel:
"""Get branding configuration from feature service."""
return FeatureService.get_system_features().branding
class EmailSender(Protocol):
"""Protocol for email sending abstraction."""
def send_email(self, to: str, subject: str, html_content: str):
"""Send email with given parameters."""
...
class FlaskMailSender:
"""Flask-Mail based email sender."""
def send_email(self, to: str, subject: str, html_content: str):
"""Send email using Flask-Mail."""
if mail.is_inited():
mail.send(to=to, subject=subject, html=html_content)
class EmailI18nService:
"""
Main service for internationalized email handling.
This service provides a clean API for sending internationalized emails
with proper branding support and template management.
"""
def __init__(
self,
config: EmailI18nConfig,
renderer: EmailRenderer,
branding_service: BrandingService,
sender: EmailSender,
):
self._config = config
self._renderer = renderer
self._branding_service = branding_service
self._sender = sender
def send_email(
self,
email_type: EmailType,
language_code: str,
to: str,
template_context: dict[str, Any] | None = None,
):
"""
Send internationalized email with branding support.
Args:
email_type: Type of email to send
language_code: Target language code
to: Recipient email address
template_context: Additional context for template rendering
"""
if template_context is None:
template_context = {}
language = EmailLanguage.from_language_code(language_code)
email_content = self._render_email_content(email_type, language, template_context)
self._sender.send_email(to=to, subject=email_content.subject, html_content=email_content.html_content)
def send_change_email(
self,
language_code: str,
to: str,
code: str,
phase: str,
):
"""
Send change email notification with phase-specific handling.
Args:
language_code: Target language code
to: Recipient email address
code: Verification code
phase: Either 'old_email' or 'new_email'
"""
if phase == "old_email":
email_type = EmailType.CHANGE_EMAIL_OLD
elif phase == "new_email":
email_type = EmailType.CHANGE_EMAIL_NEW
else:
raise ValueError(f"Invalid phase: {phase}. Must be 'old_email' or 'new_email'")
self.send_email(
email_type=email_type,
language_code=language_code,
to=to,
template_context={
"to": to,
"code": code,
},
)
def send_raw_email(
self,
to: str | list[str],
subject: str,
html_content: str,
):
"""
Send a raw email directly without template processing.
This method is provided for backward compatibility with legacy email
sending that uses pre-rendered HTML content (e.g., enterprise emails
with custom templates).
Args:
to: Recipient email address(es)
subject: Email subject
html_content: Pre-rendered HTML content
"""
if isinstance(to, list):
for recipient in to:
self._sender.send_email(to=recipient, subject=subject, html_content=html_content)
else:
self._sender.send_email(to=to, subject=subject, html_content=html_content)
def _render_email_content(
self,
email_type: EmailType,
language: EmailLanguage,
template_context: dict[str, Any],
) -> EmailContent:
"""Render email content with branding and internationalization."""
template_config = self._config.get_template(email_type, language)
branding = self._branding_service.get_branding_config()
# Determine template path based on branding
template_path = template_config.branded_template_path if branding.enabled else template_config.template_path
# Prepare template context with branding information
full_context = {
**template_context,
"branding_enabled": branding.enabled,
"application_title": branding.application_title if branding.enabled else "Dify",
}
# Render template
html_content = self._renderer.render_template(template_path, **full_context)
# Apply templating to subject with all context variables
subject = template_config.subject
try:
subject = subject.format(**full_context)
except KeyError:
# If template variables are missing, fall back to basic formatting
if branding.enabled and "{application_title}" in subject:
subject = subject.format(application_title=branding.application_title)
return EmailContent(
subject=subject,
html_content=html_content,
template_context=full_context,
)
def create_default_email_config() -> EmailI18nConfig:
"""Create default email i18n configuration with all supported templates."""
templates: dict[EmailType, dict[EmailLanguage, EmailTemplate]] = {
EmailType.RESET_PASSWORD: {
EmailLanguage.EN_US: EmailTemplate(
subject="Set Your {application_title} Password",
template_path="reset_password_mail_template_en-US.html",
branded_template_path="without-brand/reset_password_mail_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="设置您的 {application_title} 密码",
template_path="reset_password_mail_template_zh-CN.html",
branded_template_path="without-brand/reset_password_mail_template_zh-CN.html",
),
},
EmailType.INVITE_MEMBER: {
EmailLanguage.EN_US: EmailTemplate(
subject="Join {application_title} Workspace Now",
template_path="invite_member_mail_template_en-US.html",
branded_template_path="without-brand/invite_member_mail_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="立即加入 {application_title} 工作空间",
template_path="invite_member_mail_template_zh-CN.html",
branded_template_path="without-brand/invite_member_mail_template_zh-CN.html",
),
},
EmailType.EMAIL_CODE_LOGIN: {
EmailLanguage.EN_US: EmailTemplate(
subject="{application_title} Login Code",
template_path="email_code_login_mail_template_en-US.html",
branded_template_path="without-brand/email_code_login_mail_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="{application_title} 登录验证码",
template_path="email_code_login_mail_template_zh-CN.html",
branded_template_path="without-brand/email_code_login_mail_template_zh-CN.html",
),
},
EmailType.CHANGE_EMAIL_OLD: {
EmailLanguage.EN_US: EmailTemplate(
subject="Check your current email",
template_path="change_mail_confirm_old_template_en-US.html",
branded_template_path="without-brand/change_mail_confirm_old_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="检测您现在的邮箱",
template_path="change_mail_confirm_old_template_zh-CN.html",
branded_template_path="without-brand/change_mail_confirm_old_template_zh-CN.html",
),
},
EmailType.CHANGE_EMAIL_NEW: {
EmailLanguage.EN_US: EmailTemplate(
subject="Confirm your new email address",
template_path="change_mail_confirm_new_template_en-US.html",
branded_template_path="without-brand/change_mail_confirm_new_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="确认您的邮箱地址变更",
template_path="change_mail_confirm_new_template_zh-CN.html",
branded_template_path="without-brand/change_mail_confirm_new_template_zh-CN.html",
),
},
EmailType.CHANGE_EMAIL_COMPLETED: {
EmailLanguage.EN_US: EmailTemplate(
subject="Your login email has been changed",
template_path="change_mail_completed_template_en-US.html",
branded_template_path="without-brand/change_mail_completed_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="您的登录邮箱已更改",
template_path="change_mail_completed_template_zh-CN.html",
branded_template_path="without-brand/change_mail_completed_template_zh-CN.html",
),
},
EmailType.OWNER_TRANSFER_CONFIRM: {
EmailLanguage.EN_US: EmailTemplate(
subject="Verify Your Request to Transfer Workspace Ownership",
template_path="transfer_workspace_owner_confirm_template_en-US.html",
branded_template_path="without-brand/transfer_workspace_owner_confirm_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="验证您转移工作空间所有权的请求",
template_path="transfer_workspace_owner_confirm_template_zh-CN.html",
branded_template_path="without-brand/transfer_workspace_owner_confirm_template_zh-CN.html",
),
},
EmailType.OWNER_TRANSFER_OLD_NOTIFY: {
EmailLanguage.EN_US: EmailTemplate(
subject="Workspace ownership has been transferred",
template_path="transfer_workspace_old_owner_notify_template_en-US.html",
branded_template_path="without-brand/transfer_workspace_old_owner_notify_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="工作区所有权已转移",
template_path="transfer_workspace_old_owner_notify_template_zh-CN.html",
branded_template_path="without-brand/transfer_workspace_old_owner_notify_template_zh-CN.html",
),
},
EmailType.OWNER_TRANSFER_NEW_NOTIFY: {
EmailLanguage.EN_US: EmailTemplate(
subject="You are now the owner of {WorkspaceName}",
template_path="transfer_workspace_new_owner_notify_template_en-US.html",
branded_template_path="without-brand/transfer_workspace_new_owner_notify_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="您现在是 {WorkspaceName} 的所有者",
template_path="transfer_workspace_new_owner_notify_template_zh-CN.html",
branded_template_path="without-brand/transfer_workspace_new_owner_notify_template_zh-CN.html",
),
},
EmailType.ACCOUNT_DELETION_SUCCESS: {
EmailLanguage.EN_US: EmailTemplate(
subject="Your Dify.AI Account Has Been Successfully Deleted",
template_path="delete_account_success_template_en-US.html",
branded_template_path="delete_account_success_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="您的 Dify.AI 账户已成功删除",
template_path="delete_account_success_template_zh-CN.html",
branded_template_path="delete_account_success_template_zh-CN.html",
),
},
EmailType.ACCOUNT_DELETION_VERIFICATION: {
EmailLanguage.EN_US: EmailTemplate(
subject="Dify.AI Account Deletion and Verification",
template_path="delete_account_code_email_template_en-US.html",
branded_template_path="delete_account_code_email_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="Dify.AI 账户删除和验证",
template_path="delete_account_code_email_template_zh-CN.html",
branded_template_path="delete_account_code_email_template_zh-CN.html",
),
},
EmailType.QUEUE_MONITOR_ALERT: {
EmailLanguage.EN_US: EmailTemplate(
subject="Alert: Dataset Queue pending tasks exceeded the limit",
template_path="queue_monitor_alert_email_template_en-US.html",
branded_template_path="queue_monitor_alert_email_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="警报:数据集队列待处理任务超过限制",
template_path="queue_monitor_alert_email_template_zh-CN.html",
branded_template_path="queue_monitor_alert_email_template_zh-CN.html",
),
},
EmailType.DOCUMENT_CLEAN_NOTIFY: {
EmailLanguage.EN_US: EmailTemplate(
subject="Dify Knowledge base auto disable notification",
template_path="clean_document_job_mail_template-US.html",
branded_template_path="clean_document_job_mail_template-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="Dify 知识库自动禁用通知",
template_path="clean_document_job_mail_template_zh-CN.html",
branded_template_path="clean_document_job_mail_template_zh-CN.html",
),
},
EmailType.TRIGGER_EVENTS_LIMIT_SANDBOX: {
EmailLanguage.EN_US: EmailTemplate(
subject="Youve reached your Sandbox Trigger Events limit",
template_path="trigger_events_limit_template_en-US.html",
branded_template_path="without-brand/trigger_events_limit_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="您的 Sandbox 触发事件额度已用尽",
template_path="trigger_events_limit_template_zh-CN.html",
branded_template_path="without-brand/trigger_events_limit_template_zh-CN.html",
),
},
EmailType.TRIGGER_EVENTS_LIMIT_PROFESSIONAL: {
EmailLanguage.EN_US: EmailTemplate(
subject="Youve reached your monthly Trigger Events limit",
template_path="trigger_events_limit_template_en-US.html",
branded_template_path="without-brand/trigger_events_limit_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="您的月度触发事件额度已用尽",
template_path="trigger_events_limit_template_zh-CN.html",
branded_template_path="without-brand/trigger_events_limit_template_zh-CN.html",
),
},
EmailType.TRIGGER_EVENTS_USAGE_WARNING_SANDBOX: {
EmailLanguage.EN_US: EmailTemplate(
subject="Youre nearing your Sandbox Trigger Events limit",
template_path="trigger_events_usage_warning_template_en-US.html",
branded_template_path="without-brand/trigger_events_usage_warning_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="您的 Sandbox 触发事件额度接近上限",
template_path="trigger_events_usage_warning_template_zh-CN.html",
branded_template_path="without-brand/trigger_events_usage_warning_template_zh-CN.html",
),
},
EmailType.TRIGGER_EVENTS_USAGE_WARNING_PROFESSIONAL: {
EmailLanguage.EN_US: EmailTemplate(
subject="Youre nearing your Monthly Trigger Events limit",
template_path="trigger_events_usage_warning_template_en-US.html",
branded_template_path="without-brand/trigger_events_usage_warning_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="您的月度触发事件额度接近上限",
template_path="trigger_events_usage_warning_template_zh-CN.html",
branded_template_path="without-brand/trigger_events_usage_warning_template_zh-CN.html",
),
},
EmailType.API_RATE_LIMIT_LIMIT_SANDBOX: {
EmailLanguage.EN_US: EmailTemplate(
subject="Youve reached your API Rate Limit",
template_path="api_rate_limit_limit_template_en-US.html",
branded_template_path="without-brand/api_rate_limit_limit_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="您的 API 速率额度已用尽",
template_path="api_rate_limit_limit_template_zh-CN.html",
branded_template_path="without-brand/api_rate_limit_limit_template_zh-CN.html",
),
},
EmailType.API_RATE_LIMIT_WARNING_SANDBOX: {
EmailLanguage.EN_US: EmailTemplate(
subject="Youre nearing your API Rate Limit",
template_path="api_rate_limit_warning_template_en-US.html",
branded_template_path="without-brand/api_rate_limit_warning_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="您的 API 速率额度接近上限",
template_path="api_rate_limit_warning_template_zh-CN.html",
branded_template_path="without-brand/api_rate_limit_warning_template_zh-CN.html",
),
},
EmailType.EMAIL_REGISTER: {
EmailLanguage.EN_US: EmailTemplate(
subject="Register Your {application_title} Account",
template_path="register_email_template_en-US.html",
branded_template_path="without-brand/register_email_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="注册您的 {application_title} 账户",
template_path="register_email_template_zh-CN.html",
branded_template_path="without-brand/register_email_template_zh-CN.html",
),
},
EmailType.EMAIL_REGISTER_WHEN_ACCOUNT_EXIST: {
EmailLanguage.EN_US: EmailTemplate(
subject="Register Your {application_title} Account",
template_path="register_email_when_account_exist_template_en-US.html",
branded_template_path="without-brand/register_email_when_account_exist_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="注册您的 {application_title} 账户",
template_path="register_email_when_account_exist_template_zh-CN.html",
branded_template_path="without-brand/register_email_when_account_exist_template_zh-CN.html",
),
},
EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST: {
EmailLanguage.EN_US: EmailTemplate(
subject="Reset Your {application_title} Password",
template_path="reset_password_mail_when_account_not_exist_template_en-US.html",
branded_template_path="without-brand/reset_password_mail_when_account_not_exist_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="重置您的 {application_title} 密码",
template_path="reset_password_mail_when_account_not_exist_template_zh-CN.html",
branded_template_path="without-brand/reset_password_mail_when_account_not_exist_template_zh-CN.html",
),
},
EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER: {
EmailLanguage.EN_US: EmailTemplate(
subject="Reset Your {application_title} Password",
template_path="reset_password_mail_when_account_not_exist_no_register_template_en-US.html",
branded_template_path="without-brand/reset_password_mail_when_account_not_exist_no_register_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="重置您的 {application_title} 密码",
template_path="reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html",
branded_template_path="without-brand/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html",
),
},
}
return EmailI18nConfig(templates=templates)
# Singleton instance for application-wide use
def get_default_email_i18n_service() -> EmailI18nService:
"""Get configured email i18n service with default dependencies."""
config = create_default_email_config()
renderer = FlaskEmailRenderer()
branding_service = FeatureBrandingService()
sender = FlaskMailSender()
return EmailI18nService(
config=config,
renderer=renderer,
branding_service=branding_service,
sender=sender,
)
# Global instance
_email_i18n_service: EmailI18nService | None = None
def get_email_i18n_service() -> EmailI18nService:
"""Get global email i18n service instance."""
global _email_i18n_service
if _email_i18n_service is None:
_email_i18n_service = get_default_email_i18n_service()
return _email_i18n_service

View File

@@ -0,0 +1,15 @@
from werkzeug.exceptions import HTTPException
class BaseHTTPException(HTTPException):
error_code: str = "unknown"
data: dict | None = None
def __init__(self, description=None, response=None):
super().__init__(description, response)
self.data = {
"code": self.error_code,
"message": self.description,
"status": self.code,
}

View File

@@ -0,0 +1,142 @@
import re
import sys
from collections.abc import Mapping
from typing import Any
from flask import Blueprint, Flask, current_app, got_request_exception
from flask_restx import Api
from werkzeug.exceptions import HTTPException
from werkzeug.http import HTTP_STATUS_CODES
from configs import dify_config
from core.errors.error import AppInvokeQuotaExceededError
from libs.token import build_force_logout_cookie_headers
def http_status_message(code):
return HTTP_STATUS_CODES.get(code, "")
def register_external_error_handlers(api: Api):
@api.errorhandler(HTTPException)
def handle_http_exception(e: HTTPException):
got_request_exception.send(current_app, exception=e)
# If Werkzeug already prepared a Response, just use it.
if e.response is not None:
return e.response
status_code = getattr(e, "code", 500) or 500
# Build a safe, dict-like payload
default_data = {
"code": re.sub(r"(?<!^)(?=[A-Z])", "_", type(e).__name__).lower(),
"message": getattr(e, "description", http_status_message(status_code)),
"status": status_code,
}
if default_data["message"] == "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)":
default_data["message"] = "Invalid JSON payload received or JSON payload is empty."
# Use headers on the exception if present; otherwise none.
headers = {}
exc_headers = getattr(e, "headers", None)
if exc_headers:
headers.update(exc_headers)
# Payload per status
if status_code == 406 and api.default_mediatype is None:
data = {"code": "not_acceptable", "message": default_data["message"], "status": status_code}
return data, status_code, headers
elif status_code == 400:
msg = default_data["message"]
if isinstance(msg, Mapping) and msg:
# Convert param errors like {"field": "reason"} into a friendly shape
param_key, param_value = next(iter(msg.items()))
data = {
"code": "invalid_param",
"message": str(param_value),
"params": param_key,
"status": status_code,
}
else:
data = {**default_data}
data.setdefault("code", "unknown")
return data, status_code, headers
else:
data = {**default_data}
data.setdefault("code", "unknown")
# If you need WWW-Authenticate for 401, add it to headers
if status_code == 401:
headers["WWW-Authenticate"] = 'Bearer realm="api"'
# Check if this is a forced logout error - clear cookies
error_code = getattr(e, "error_code", None)
if error_code == "unauthorized_and_force_logout":
# Add Set-Cookie headers to clear auth cookies
headers["Set-Cookie"] = build_force_logout_cookie_headers()
return data, status_code, headers
_ = handle_http_exception
@api.errorhandler(ValueError)
def handle_value_error(e: ValueError):
got_request_exception.send(current_app, exception=e)
status_code = 400
data = {"code": "invalid_param", "message": str(e), "status": status_code}
return data, status_code
_ = handle_value_error
@api.errorhandler(AppInvokeQuotaExceededError)
def handle_quota_exceeded(e: AppInvokeQuotaExceededError):
got_request_exception.send(current_app, exception=e)
status_code = 429
data = {"code": "too_many_requests", "message": str(e), "status": status_code}
return data, status_code
_ = handle_quota_exceeded
@api.errorhandler(Exception)
def handle_general_exception(e: Exception):
got_request_exception.send(current_app, exception=e)
status_code = 500
data: dict[str, Any] = getattr(e, "data", {"message": http_status_message(status_code)})
# 🔒 Normalize non-mapping data (e.g., if someone set e.data = Response)
if not isinstance(data, dict):
data = {"message": str(e)}
data.setdefault("code", "unknown")
data.setdefault("status", status_code)
# Log stack
exc_info: Any = sys.exc_info()
if exc_info[1] is None:
exc_info = (None, None, None)
current_app.log_exception(exc_info)
return data, status_code
_ = handle_general_exception
class ExternalApi(Api):
_authorizations = {
"Bearer": {
"type": "apiKey",
"in": "header",
"name": "Authorization",
"description": "Type: Bearer {your-api-key}",
}
}
def __init__(self, app: Blueprint | Flask, *args, **kwargs):
kwargs.setdefault("authorizations", self._authorizations)
kwargs.setdefault("security", "Bearer")
kwargs["add_specs"] = dify_config.SWAGGER_UI_ENABLED
kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False
# manual separate call on construction and init_app to ensure configs in kwargs effective
super().__init__(app=None, *args, **kwargs)
self.init_app(app, **kwargs)
register_external_error_handlers(self)

View File

@@ -0,0 +1,30 @@
from pathlib import Path
def search_file_upwards(
base_dir_path: Path,
target_file_name: str,
max_search_parent_depth: int,
) -> Path:
"""
Find a target file in the current directory or its parent directories up to a specified depth.
:param base_dir_path: Starting directory path to search from.
:param target_file_name: Name of the file to search for.
:param max_search_parent_depth: Maximum number of parent directories to search upwards.
:return: Path of the file if found, otherwise None.
"""
current_path = base_dir_path.resolve()
for _ in range(max_search_parent_depth):
candidate_path = current_path / target_file_name
if candidate_path.is_file():
return candidate_path
parent_path = current_path.parent
if parent_path == current_path: # reached the root directory
break
else:
current_path = parent_path
raise ValueError(
f"File '{target_file_name}' not found in the directory '{base_dir_path.resolve()}' or its parent directories"
f" in depth of {max_search_parent_depth}."
)

View File

@@ -0,0 +1,66 @@
import contextvars
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TypeVar
from flask import Flask, g
T = TypeVar("T")
@contextmanager
def preserve_flask_contexts(
flask_app: Flask,
context_vars: contextvars.Context,
) -> Iterator[None]:
"""
A context manager that handles:
1. flask-login's UserProxy copy
2. ContextVars copy
3. flask_app.app_context()
This context manager ensures that the Flask application context is properly set up,
the current user is preserved across context boundaries, and any provided context variables
are set within the new context.
Note:
This manager aims to allow use current_user cross thread and app context,
but it's not the recommend use, it's better to pass user directly in parameters.
Args:
flask_app: The Flask application instance
context_vars: contextvars.Context object containing context variables to be set in the new context
Yields:
None
Example:
```python
with preserve_flask_contexts(flask_app, context_vars=context_vars):
# Code that needs Flask app context and context variables
# Current user will be preserved if available
```
"""
# Set context variables if provided
if context_vars:
for var, val in context_vars.items():
var.set(val)
# Save current user before entering new app context
saved_user = None
# Check for user in g (works in both request context and app context)
if hasattr(g, "_login_user"):
saved_user = g._login_user
# Enter Flask app context
with flask_app.app_context():
try:
# Restore user in new app context if it was saved
if saved_user is not None:
g._login_user = saved_user
# Yield control back to the caller
yield
finally:
# Any cleanup can be added here if needed
pass

View File

@@ -0,0 +1,241 @@
#
# Cipher/PKCS1_OAEP.py : PKCS#1 OAEP
#
# ===================================================================
# The contents of this file are dedicated to the public domain. To
# the extent that dedication to the public domain is not available,
# everyone is granted a worldwide, perpetual, royalty-free,
# non-exclusive license to exercise all rights associated with the
# contents of this file for any purpose whatsoever.
# No rights are reserved.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ===================================================================
from hashlib import sha1
import Crypto.Hash.SHA1
import Crypto.Util.number
import gmpy2
from Crypto import Random
from Crypto.Signature.pss import MGF1
from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes
from Crypto.Util.py3compat import bord
from Crypto.Util.strxor import strxor
class PKCS1OAepCipher:
"""Cipher object for PKCS#1 v1.5 OAEP.
Do not create directly: use :func:`new` instead."""
def __init__(self, key, hashAlgo, mgfunc, label, randfunc):
"""Initialize this PKCS#1 OAEP cipher object.
:Parameters:
key : an RSA key object
If a private half is given, both encryption and decryption are possible.
If a public half is given, only encryption is possible.
hashAlgo : hash object
The hash function to use. This can be a module under `Crypto.Hash`
or an existing hash object created from any of such modules. If not specified,
`Crypto.Hash.SHA1` is used.
mgfunc : callable
A mask generation function that accepts two parameters: a string to
use as seed, and the length of the mask to generate, in bytes.
If not specified, the standard MGF1 consistent with ``hashAlgo`` is used (a safe choice).
label : bytes/bytearray/memoryview
A label to apply to this particular encryption. If not specified,
an empty string is used. Specifying a label does not improve
security.
randfunc : callable
A function that returns random bytes.
:attention: Modify the mask generation function only if you know what you are doing.
Sender and receiver must use the same one.
"""
self._key = key
if hashAlgo:
self._hashObj = hashAlgo
else:
self._hashObj = Crypto.Hash.SHA1
if mgfunc:
self._mgf = mgfunc
else:
self._mgf = lambda x, y: MGF1(x, y, self._hashObj)
self._label = bytes(label)
self._randfunc = randfunc
def can_encrypt(self):
"""Legacy function to check if you can call :meth:`encrypt`.
.. deprecated:: 3.0"""
return self._key.can_encrypt()
def can_decrypt(self):
"""Legacy function to check if you can call :meth:`decrypt`.
.. deprecated:: 3.0"""
return self._key.can_decrypt()
def encrypt(self, message):
"""Encrypt a message with PKCS#1 OAEP.
:param message:
The message to encrypt, also known as plaintext. It can be of
variable length, but not longer than the RSA modulus (in bytes)
minus 2, minus twice the hash output size.
For instance, if you use RSA 2048 and SHA-256, the longest message
you can encrypt is 190 byte long.
:type message: bytes/bytearray/memoryview
:returns: The ciphertext, as large as the RSA modulus.
:rtype: bytes
:raises ValueError:
if the message is too long.
"""
# See 7.1.1 in RFC3447
modBits = Crypto.Util.number.size(self._key.n)
k = ceil_div(modBits, 8) # Convert from bits to bytes
hLen = self._hashObj.digest_size
mLen = len(message)
# Step 1b
ps_len = k - mLen - 2 * hLen - 2
if ps_len < 0:
raise ValueError("Plaintext is too long.")
# Step 2a
lHash = sha1(self._label).digest()
# Step 2b
ps = b"\x00" * ps_len
# Step 2c
db = lHash + ps + b"\x01" + bytes(message)
# Step 2d
ros = self._randfunc(hLen)
# Step 2e
dbMask = self._mgf(ros, k - hLen - 1)
# Step 2f
maskedDB = strxor(db, dbMask)
# Step 2g
seedMask = self._mgf(maskedDB, hLen)
# Step 2h
maskedSeed = strxor(ros, seedMask)
# Step 2i
em = b"\x00" + maskedSeed + maskedDB
# Step 3a (OS2IP)
em_int = bytes_to_long(em)
# Step 3b (RSAEP)
m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)
# Step 3c (I2OSP)
c = long_to_bytes(m_int, k)
return c
def decrypt(self, ciphertext):
"""Decrypt a message with PKCS#1 OAEP.
:param ciphertext: The encrypted message.
:type ciphertext: bytes/bytearray/memoryview
:returns: The original message (plaintext).
:rtype: bytes
:raises ValueError:
if the ciphertext has the wrong length, or if decryption
fails the integrity check (in which case, the decryption
key is probably wrong).
:raises TypeError:
if the RSA key has no private half (i.e. you are trying
to decrypt using a public key).
"""
# See 7.1.2 in RFC3447
modBits = Crypto.Util.number.size(self._key.n)
k = ceil_div(modBits, 8) # Convert from bits to bytes
hLen = self._hashObj.digest_size
# Step 1b and 1c
if len(ciphertext) != k or k < hLen + 2:
raise ValueError("Ciphertext with incorrect length.")
# Step 2a (O2SIP)
ct_int = bytes_to_long(ciphertext)
# Step 2b (RSADP)
# m_int = self._key._decrypt(ct_int)
m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)
# Complete step 2c (I2OSP)
em = long_to_bytes(m_int, k)
# Step 3a
lHash = sha1(self._label).digest()
# Step 3b
y = em[0]
# y must be 0, but we MUST NOT check it here in order not to
# allow attacks like Manger's (http://dl.acm.org/citation.cfm?id=704143)
maskedSeed = em[1 : hLen + 1]
maskedDB = em[hLen + 1 :]
# Step 3c
seedMask = self._mgf(maskedDB, hLen)
# Step 3d
seed = strxor(maskedSeed, seedMask)
# Step 3e
dbMask = self._mgf(seed, k - hLen - 1)
# Step 3f
db = strxor(maskedDB, dbMask)
# Step 3g
one_pos = hLen + db[hLen:].find(b"\x01")
lHash1 = db[:hLen]
invalid = bord(y) | int(one_pos < hLen) # type: ignore[arg-type]
hash_compare = strxor(lHash1, lHash)
for x in hash_compare:
invalid |= bord(x) # type: ignore[arg-type]
for x in db[hLen:one_pos]:
invalid |= bord(x) # type: ignore[arg-type]
if invalid != 0:
raise ValueError("Incorrect decryption.")
# Step 4
return db[one_pos + 1 :]
def new(key, hashAlgo=None, mgfunc=None, label=b"", randfunc=None):
"""Return a cipher object :class:`PKCS1OAEP_Cipher`
that can be used to perform PKCS#1 OAEP encryption or decryption.
:param key:
The key object to use to encrypt or decrypt the message.
Decryption is only possible with a private RSA key.
:type key: RSA key object
:param hashAlgo:
The hash function to use. This can be a module under `Crypto.Hash`
or an existing hash object created from any of such modules.
If not specified, `Crypto.Hash.SHA1` is used.
:type hashAlgo: hash object
:param mgfunc:
A mask generation function that accepts two parameters: a string to
use as seed, and the length of the mask to generate, in bytes.
If not specified, the standard MGF1 consistent with ``hashAlgo`` is used (a safe choice).
:type mgfunc: callable
:param label:
A label to apply to this particular encryption. If not specified,
an empty string is used. Specifying a label does not improve
security.
:type label: bytes/bytearray/memoryview
:param randfunc:
A function that returns random bytes.
The default is `Random.get_random_bytes`.
:type randfunc: callable
"""
if randfunc is None:
randfunc = Random.get_random_bytes
return PKCS1OAepCipher(key, hashAlgo, mgfunc, label, randfunc)

380
dify/api/libs/helper.py Normal file
View File

@@ -0,0 +1,380 @@
import json
import logging
import re
import secrets
import string
import struct
import subprocess
import time
import uuid
from collections.abc import Generator, Mapping
from datetime import datetime
from hashlib import sha256
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from zoneinfo import available_timezones
from flask import Response, stream_with_context
from flask_restx import fields
from pydantic import BaseModel
from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.file import helpers as file_helpers
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_redis import redis_client
if TYPE_CHECKING:
from models import Account
from models.model import EndUser
logger = logging.getLogger(__name__)
def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None:
"""
Extract tenant_id from Account or EndUser object.
Args:
user: Account or EndUser object
Returns:
tenant_id string if available, None otherwise
Raises:
ValueError: If user is neither Account nor EndUser
"""
from models import Account
from models.model import EndUser
if isinstance(user, Account):
return user.current_tenant_id
elif isinstance(user, EndUser):
return user.tenant_id
else:
raise ValueError(f"Invalid user type: {type(user)}. Expected Account or EndUser.")
def run(script):
return subprocess.getstatusoutput("source /root/.bashrc && " + script)
class AppIconUrlField(fields.Raw):
def output(self, key, obj, **kwargs):
if obj is None:
return None
from models.model import App, IconType, Site
if isinstance(obj, dict) and "app" in obj:
obj = obj["app"]
if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE:
return file_helpers.get_signed_file_url(obj.icon)
return None
class AvatarUrlField(fields.Raw):
def output(self, key, obj, **kwargs):
if obj is None:
return None
from models import Account
if isinstance(obj, Account) and obj.avatar is not None:
if obj.avatar.startswith(("http://", "https://")):
return obj.avatar
return file_helpers.get_signed_file_url(obj.avatar)
return None
class TimestampField(fields.Raw):
def format(self, value) -> int:
return int(value.timestamp())
def email(email):
# Define a regex pattern for email addresses
pattern = r"^[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}$"
# Check if the email matches the pattern
if re.match(pattern, email) is not None:
return email
error = f"{email} is not a valid email."
raise ValueError(error)
def uuid_value(value):
if value == "":
return str(value)
try:
uuid_obj = uuid.UUID(value)
return str(uuid_obj)
except ValueError:
error = f"{value} is not a valid uuid."
raise ValueError(error)
def alphanumeric(value: str):
# check if the value is alphanumeric and underlined
if re.match(r"^[a-zA-Z0-9_]+$", value):
return value
raise ValueError(f"{value} is not a valid alphanumeric value")
def timestamp_value(timestamp):
try:
int_timestamp = int(timestamp)
if int_timestamp < 0:
raise ValueError
return int_timestamp
except ValueError:
error = f"{timestamp} is not a valid timestamp."
raise ValueError(error)
class StrLen:
"""Restrict input to an integer in a range (inclusive)"""
def __init__(self, max_length, argument="argument"):
self.max_length = max_length
self.argument = argument
def __call__(self, value):
length = len(value)
if length > self.max_length:
error = "Invalid {arg}: {val}. {arg} cannot exceed length {length}".format(
arg=self.argument, val=value, length=self.max_length
)
raise ValueError(error)
return value
class DatetimeString:
def __init__(self, format, argument="argument"):
self.format = format
self.argument = argument
def __call__(self, value):
try:
datetime.strptime(value, self.format)
except ValueError:
error = "Invalid {arg}: {val}. {arg} must be conform to the format {format}".format(
arg=self.argument, val=value, format=self.format
)
raise ValueError(error)
return value
def timezone(timezone_string):
if timezone_string and timezone_string in available_timezones():
return timezone_string
error = f"{timezone_string} is not a valid timezone."
raise ValueError(error)
def convert_datetime_to_date(field, target_timezone: str = ":tz"):
if dify_config.DB_TYPE == "postgresql":
return f"DATE(DATE_TRUNC('day', {field} AT TIME ZONE 'UTC' AT TIME ZONE {target_timezone}))"
elif dify_config.DB_TYPE == "mysql":
return f"DATE(CONVERT_TZ({field}, 'UTC', {target_timezone}))"
else:
raise NotImplementedError(f"Unsupported database type: {dify_config.DB_TYPE}")
def generate_string(n):
letters_digits = string.ascii_letters + string.digits
result = ""
for _ in range(n):
result += secrets.choice(letters_digits)
return result
def extract_remote_ip(request) -> str:
if request.headers.get("CF-Connecting-IP"):
return cast(str, request.headers.get("CF-Connecting-IP"))
elif request.headers.getlist("X-Forwarded-For"):
return cast(str, request.headers.getlist("X-Forwarded-For")[0])
else:
return cast(str, request.remote_addr)
def generate_text_hash(text: str) -> str:
hash_text = str(text) + "None"
return sha256(hash_text.encode()).hexdigest()
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(jsonable_encoder(response)), status=200, mimetype="application/json")
else:
def generate() -> Generator:
yield from response
return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
def length_prefixed_response(magic_number: int, response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
"""
This function is used to return a response with a length prefix.
Magic number is a one byte number that indicates the type of the response.
For a compatibility with latest plugin daemon https://github.com/langgenius/dify-plugin-daemon/pull/341
Avoid using line-based response, it leads a memory issue.
We uses following format:
| Field | Size | Description |
|---------------|----------|---------------------------------|
| Magic Number | 1 byte | Magic number identifier |
| Reserved | 1 byte | Reserved field |
| Header Length | 2 bytes | Header length (usually 0xa) |
| Data Length | 4 bytes | Length of the data |
| Reserved | 6 bytes | Reserved fields |
| Data | Variable | Actual data content |
| Reserved Fields | Header | Data |
|-----------------|----------|----------|
| 4 bytes total | Variable | Variable |
all data is in little endian
"""
def pack_response_with_length_prefix(response: bytes) -> bytes:
header_length = 0xA
data_length = len(response)
# | Magic Number 1byte | Reserved 1byte | Header Length 2bytes | Data Length 4bytes | Reserved 6bytes | Data
return struct.pack("<BBHI", magic_number, 0, header_length, data_length) + b"\x00" * 6 + response
if isinstance(response, dict):
return Response(
response=pack_response_with_length_prefix(json.dumps(jsonable_encoder(response)).encode("utf-8")),
status=200,
mimetype="application/json",
)
elif isinstance(response, BaseModel):
return Response(
response=pack_response_with_length_prefix(response.model_dump_json().encode("utf-8")),
status=200,
mimetype="application/json",
)
def generate() -> Generator:
for chunk in response:
if isinstance(chunk, str):
yield pack_response_with_length_prefix(chunk.encode("utf-8"))
else:
yield pack_response_with_length_prefix(chunk)
return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
class TokenManager:
@classmethod
def generate_token(
cls,
token_type: str,
account: Optional["Account"] = None,
email: str | None = None,
additional_data: dict | None = None,
) -> str:
if account is None and email is None:
raise ValueError("Account or email must be provided")
account_id = account.id if account else None
account_email = account.email if account else email
if account_id:
old_token = cls._get_current_token_for_account(account_id, token_type)
if old_token:
if isinstance(old_token, bytes):
old_token = old_token.decode("utf-8")
cls.revoke_token(old_token, token_type)
token = str(uuid.uuid4())
token_data = {"account_id": account_id, "email": account_email, "token_type": token_type}
if additional_data:
token_data.update(additional_data)
expiry_minutes = dify_config.model_dump().get(f"{token_type.upper()}_TOKEN_EXPIRY_MINUTES")
if expiry_minutes is None:
raise ValueError(f"Expiry minutes for {token_type} token is not set")
token_key = cls._get_token_key(token, token_type)
expiry_seconds = int(expiry_minutes * 60)
redis_client.setex(token_key, expiry_seconds, json.dumps(token_data))
if account_id:
cls._set_current_token_for_account(account_id, token, token_type, expiry_minutes)
return token
@classmethod
def _get_token_key(cls, token: str, token_type: str) -> str:
return f"{token_type}:token:{token}"
@classmethod
def revoke_token(cls, token: str, token_type: str):
token_key = cls._get_token_key(token, token_type)
redis_client.delete(token_key)
@classmethod
def get_token_data(cls, token: str, token_type: str) -> dict[str, Any] | None:
key = cls._get_token_key(token, token_type)
token_data_json = redis_client.get(key)
if token_data_json is None:
logger.warning("%s token %s not found with key %s", token_type, token, key)
return None
token_data: dict[str, Any] | None = json.loads(token_data_json)
return token_data
@classmethod
def _get_current_token_for_account(cls, account_id: str, token_type: str) -> str | None:
key = cls._get_account_token_key(account_id, token_type)
current_token: str | None = redis_client.get(key)
return current_token
@classmethod
def _set_current_token_for_account(
cls, account_id: str, token: str, token_type: str, expiry_minutes: Union[int, float]
):
key = cls._get_account_token_key(account_id, token_type)
expiry_seconds = int(expiry_minutes * 60)
redis_client.setex(key, expiry_seconds, token)
@classmethod
def _get_account_token_key(cls, account_id: str, token_type: str) -> str:
return f"{token_type}:account:{account_id}"
class RateLimiter:
def __init__(self, prefix: str, max_attempts: int, time_window: int):
self.prefix = prefix
self.max_attempts = max_attempts
self.time_window = time_window
def _get_key(self, email: str) -> str:
return f"{self.prefix}:{email}"
def is_rate_limited(self, email: str) -> bool:
key = self._get_key(email)
current_time = int(time.time())
window_start_time = current_time - self.time_window
redis_client.zremrangebyscore(key, "-inf", window_start_time)
attempts = redis_client.zcard(key)
if attempts and int(attempts) >= self.max_attempts:
return True
return False
def increment_rate_limit(self, email: str):
key = self._get_key(email)
current_time = int(time.time())
redis_client.zadd(key, {current_time: current_time})
redis_client.expire(key, self.time_window * 2)

View File

@@ -0,0 +1,5 @@
class InfiniteScrollPagination:
def __init__(self, data, limit, has_more):
self.data = data
self.limit = limit
self.has_more = has_more

View File

@@ -0,0 +1,52 @@
import json
from core.llm_generator.output_parser.errors import OutputParserError
def parse_json_markdown(json_string: str):
# Get json from the backticks/braces
json_string = json_string.strip()
starts = ["```json", "```", "``", "`", "{", "["]
ends = ["```", "``", "`", "}", "]"]
end_index = -1
start_index = 0
parsed: dict = {}
for s in starts:
start_index = json_string.find(s)
if start_index != -1:
if json_string[start_index] not in ("{", "["):
start_index += len(s)
break
if start_index != -1:
for e in ends:
end_index = json_string.rfind(e, start_index)
if end_index != -1:
if json_string[end_index] in ("}", "]"):
end_index += 1
break
if start_index != -1 and end_index != -1 and start_index < end_index:
extracted_content = json_string[start_index:end_index].strip()
parsed = json.loads(extracted_content)
else:
raise ValueError("could not find json block in the output.")
return parsed
def parse_and_check_json_markdown(text: str, expected_keys: list[str]):
try:
json_obj = parse_json_markdown(text)
except json.JSONDecodeError as e:
raise OutputParserError(f"got invalid json object. error: {e}")
if isinstance(json_obj, list):
if len(json_obj) == 1 and isinstance(json_obj[0], dict):
json_obj = json_obj[0]
else:
raise OutputParserError(f"got invalid return object. obj:{json_obj}")
for key in expected_keys:
if key not in json_obj:
raise OutputParserError(
f"got invalid return object. expected key `{key}` to be present, but got {json_obj}"
)
return json_obj

98
dify/api/libs/login.py Normal file
View File

@@ -0,0 +1,98 @@
from collections.abc import Callable
from functools import wraps
from typing import Any
from flask import current_app, g, has_request_context, request
from flask_login.config import EXEMPT_METHODS
from werkzeug.local import LocalProxy
from configs import dify_config
from libs.token import check_csrf_token
from models import Account
from models.model import EndUser
def current_account_with_tenant():
"""
Resolve the underlying account for the current user proxy and ensure tenant context exists.
Allows tests to supply plain Account mocks without the LocalProxy helper.
"""
user_proxy = current_user
get_current_object = getattr(user_proxy, "_get_current_object", None)
user = get_current_object() if callable(get_current_object) else user_proxy # type: ignore
if not isinstance(user, Account):
raise ValueError("current_user must be an Account instance")
assert user.current_tenant_id is not None, "The tenant information should be loaded."
return user, user.current_tenant_id
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
def login_required(func: Callable[P, R]):
"""
If you decorate a view with this, it will ensure that the current user is
logged in and authenticated before calling the actual view. (If they are
not, it calls the :attr:`LoginManager.unauthorized` callback.) For
example::
@app.route('/post')
@login_required
def post():
pass
If there are only certain times you need to require that your user is
logged in, you can do so with::
if not current_user.is_authenticated:
return current_app.login_manager.unauthorized()
...which is essentially the code that this function adds to your views.
It can be convenient to globally turn off authentication when unit testing.
To enable this, if the application configuration variable `LOGIN_DISABLED`
is set to `True`, this decorator will be ignored.
.. Note ::
Per `W3 guidelines for CORS preflight requests
<http://www.w3.org/TR/cors/#cross-origin-request-with-preflight-0>`_,
HTTP ``OPTIONS`` requests are exempt from login checks.
:param func: The view function to decorate.
:type func: function
"""
@wraps(func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
pass
elif current_user is not None and not current_user.is_authenticated:
return current_app.login_manager.unauthorized() # type: ignore
# we put csrf validation here for less conflicts
# TODO: maybe find a better place for it.
check_csrf_token(request, current_user.id)
return current_app.ensure_sync(func)(*args, **kwargs)
return decorated_view
def _get_user() -> EndUser | Account | None:
if has_request_context():
if "_login_user" not in g:
current_app.login_manager._load_user() # type: ignore
return g._login_user
return None
#: A proxy for the current user. If no user is logged in, this will be an
#: anonymous user
# NOTE: Any here, but use _get_current_object to check the fields
current_user: Any = LocalProxy(lambda: _get_user())

View File

@@ -0,0 +1,54 @@
"""
Module loading utilities similar to Django's module_loading.
Reference implementation from Django:
https://github.com/django/django/blob/main/django/utils/module_loading.py
"""
import sys
from importlib import import_module
def cached_import(module_path: str, class_name: str):
"""
Import a module and return the named attribute/class from it, with caching.
Args:
module_path: The module path to import from
class_name: The attribute/class name to retrieve
Returns:
The imported attribute/class
"""
if not (
(module := sys.modules.get(module_path))
and (spec := getattr(module, "__spec__", None))
and getattr(spec, "_initializing", False) is False
):
module = import_module(module_path)
return getattr(module, class_name)
def import_string(dotted_path: str):
"""
Import a dotted module path and return the attribute/class designated by
the last name in the path. Raise ImportError if the import failed.
Args:
dotted_path: Full module path to the class (e.g., 'module.submodule.ClassName')
Returns:
The imported class or attribute
Raises:
ImportError: If the module or attribute cannot be imported
"""
try:
module_path, class_name = dotted_path.rsplit(".", 1)
except ValueError as err:
raise ImportError(f"{dotted_path} doesn't look like a module path") from err
try:
return cached_import(module_path, class_name)
except AttributeError as err:
raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute/class') from err

132
dify/api/libs/oauth.py Normal file
View File

@@ -0,0 +1,132 @@
import urllib.parse
from dataclasses import dataclass
import httpx
@dataclass
class OAuthUserInfo:
id: str
name: str
email: str
class OAuth:
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
def get_authorization_url(self):
raise NotImplementedError()
def get_access_token(self, code: str):
raise NotImplementedError()
def get_raw_user_info(self, token: str):
raise NotImplementedError()
def get_user_info(self, token: str) -> OAuthUserInfo:
raw_info = self.get_raw_user_info(token)
return self._transform_user_info(raw_info)
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
raise NotImplementedError()
class GitHubOAuth(OAuth):
_AUTH_URL = "https://github.com/login/oauth/authorize"
_TOKEN_URL = "https://github.com/login/oauth/access_token"
_USER_INFO_URL = "https://api.github.com/user"
_EMAIL_INFO_URL = "https://api.github.com/user/emails"
def get_authorization_url(self, invite_token: str | None = None):
params = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": "user:email", # Request only basic user information
}
if invite_token:
params["state"] = invite_token
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str):
data = {
"client_id": self.client_id,
"client_secret": self.client_secret,
"code": code,
"redirect_uri": self.redirect_uri,
}
headers = {"Accept": "application/json"}
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
response_json = response.json()
access_token = response_json.get("access_token")
if not access_token:
raise ValueError(f"Error in GitHub OAuth: {response_json}")
return access_token
def get_raw_user_info(self, token: str):
headers = {"Authorization": f"token {token}"}
response = httpx.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status()
user_info = response.json()
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
email_info = email_response.json()
primary_email: dict = next((email for email in email_info if email["primary"] == True), {})
return {**user_info, "email": primary_email.get("email", "")}
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
email = raw_info.get("email")
if not email:
email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com"
return OAuthUserInfo(id=str(raw_info["id"]), name=raw_info["name"], email=email)
class GoogleOAuth(OAuth):
_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
_TOKEN_URL = "https://oauth2.googleapis.com/token"
_USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
def get_authorization_url(self, invite_token: str | None = None):
params = {
"client_id": self.client_id,
"response_type": "code",
"redirect_uri": self.redirect_uri,
"scope": "openid email",
}
if invite_token:
params["state"] = invite_token
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str):
data = {
"client_id": self.client_id,
"client_secret": self.client_secret,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": self.redirect_uri,
}
headers = {"Accept": "application/json"}
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
response_json = response.json()
access_token = response_json.get("access_token")
if not access_token:
raise ValueError(f"Error in Google OAuth: {response_json}")
return access_token
def get_raw_user_info(self, token: str):
headers = {"Authorization": f"Bearer {token}"}
response = httpx.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status()
return response.json()
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"])

View File

@@ -0,0 +1,305 @@
import urllib.parse
from typing import Any
import httpx
from flask_login import current_user
from sqlalchemy import select
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding
class OAuthDataSource:
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
def get_authorization_url(self):
raise NotImplementedError()
def get_access_token(self, code: str):
raise NotImplementedError()
class NotionOAuth(OAuthDataSource):
_AUTH_URL = "https://api.notion.com/v1/oauth/authorize"
_TOKEN_URL = "https://api.notion.com/v1/oauth/token"
_NOTION_PAGE_SEARCH = "https://api.notion.com/v1/search"
_NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks"
_NOTION_BOT_USER = "https://api.notion.com/v1/users/me"
def get_authorization_url(self):
params = {
"client_id": self.client_id,
"response_type": "code",
"redirect_uri": self.redirect_uri,
"owner": "user",
}
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str):
data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
headers = {"Accept": "application/json"}
auth = (self.client_id, self.client_secret)
response = httpx.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
response_json = response.json()
access_token = response_json.get("access_token")
if not access_token:
raise ValueError(f"Error in Notion OAuth: {response_json}")
workspace_name = response_json.get("workspace_name")
workspace_icon = response_json.get("workspace_icon")
workspace_id = response_json.get("workspace_id")
# get all authorized pages
pages = self.get_authorized_pages(access_token)
source_info = {
"workspace_name": workspace_name,
"workspace_icon": workspace_icon,
"workspace_id": workspace_id,
"pages": pages,
"total": len(pages),
}
# save data source binding
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
)
if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=source_info,
provider="notion",
)
db.session.add(new_data_source_binding)
db.session.commit()
def save_internal_access_token(self, access_token: str):
workspace_name = self.notion_workspace_name(access_token)
workspace_icon = None
workspace_id = current_user.current_tenant_id
# get all authorized pages
pages = self.get_authorized_pages(access_token)
source_info = {
"workspace_name": workspace_name,
"workspace_icon": workspace_icon,
"workspace_id": workspace_id,
"pages": pages,
"total": len(pages),
}
# save data source binding
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
)
if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=source_info,
provider="notion",
)
db.session.add(new_data_source_binding)
db.session.commit()
def sync_data_source(self, binding_id: str):
# save data source binding
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.id == binding_id,
DataSourceOauthBinding.disabled == False,
)
)
if data_source_binding:
# get all authorized pages
pages = self.get_authorized_pages(data_source_binding.access_token)
source_info = data_source_binding.source_info
new_source_info = {
"workspace_name": source_info["workspace_name"],
"workspace_icon": source_info["workspace_icon"],
"workspace_id": source_info["workspace_id"],
"pages": pages,
"total": len(pages),
}
data_source_binding.source_info = new_source_info
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
raise ValueError("Data source binding not found")
def get_authorized_pages(self, access_token: str):
pages = []
page_results = self.notion_page_search(access_token)
database_results = self.notion_database_search(access_token)
# get page detail
for page_result in page_results:
page_id = page_result["id"]
page_name = "Untitled"
for key in page_result["properties"]:
if "title" in page_result["properties"][key] and page_result["properties"][key]["title"]:
title_list = page_result["properties"][key]["title"]
if len(title_list) > 0 and "plain_text" in title_list[0]:
page_name = title_list[0]["plain_text"]
page_icon = page_result["icon"]
if page_icon:
icon_type = page_icon["type"]
if icon_type in {"external", "file"}:
url = page_icon[icon_type]["url"]
icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"}
else:
icon = {"type": "emoji", "emoji": page_icon[icon_type]}
else:
icon = None
parent = page_result["parent"]
parent_type = parent["type"]
if parent_type == "block_id":
parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type])
elif parent_type == "workspace":
parent_id = "root"
else:
parent_id = parent[parent_type]
page = {
"page_id": page_id,
"page_name": page_name,
"page_icon": icon,
"parent_id": parent_id,
"type": "page",
}
pages.append(page)
# get database detail
for database_result in database_results:
page_id = database_result["id"]
if len(database_result["title"]) > 0:
page_name = database_result["title"][0]["plain_text"]
else:
page_name = "Untitled"
page_icon = database_result["icon"]
if page_icon:
icon_type = page_icon["type"]
if icon_type in {"external", "file"}:
url = page_icon[icon_type]["url"]
icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"}
else:
icon = {"type": icon_type, icon_type: page_icon[icon_type]}
else:
icon = None
parent = database_result["parent"]
parent_type = parent["type"]
if parent_type == "block_id":
parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type])
elif parent_type == "workspace":
parent_id = "root"
else:
parent_id = parent[parent_type]
page = {
"page_id": page_id,
"page_name": page_name,
"page_icon": icon,
"parent_id": parent_id,
"type": "database",
}
pages.append(page)
return pages
def notion_page_search(self, access_token: str):
results = []
next_cursor = None
has_more = True
while has_more:
data: dict[str, Any] = {
"filter": {"value": "page", "property": "object"},
**({"start_cursor": next_cursor} if next_cursor else {}),
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
}
response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response_json = response.json()
results.extend(response_json.get("results", []))
has_more = response_json.get("has_more", False)
next_cursor = response_json.get("next_cursor", None)
return results
def notion_block_parent_page_id(self, access_token: str, block_id: str):
headers = {
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
}
response = httpx.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
response_json = response.json()
if response.status_code != 200:
message = response_json.get("message", "unknown error")
raise ValueError(f"Error fetching block parent page ID: {message}")
parent = response_json["parent"]
parent_type = parent["type"]
if parent_type == "block_id":
return self.notion_block_parent_page_id(access_token, parent[parent_type])
return parent[parent_type]
def notion_workspace_name(self, access_token: str):
headers = {
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
}
response = httpx.get(url=self._NOTION_BOT_USER, headers=headers)
response_json = response.json()
if "object" in response_json and response_json["object"] == "user":
user_type = response_json["type"]
user_info = response_json[user_type]
if "workspace_name" in user_info:
return user_info["workspace_name"]
return "workspace"
def notion_database_search(self, access_token: str):
results = []
next_cursor = None
has_more = True
while has_more:
data: dict[str, Any] = {
"filter": {"value": "database", "property": "object"},
**({"start_cursor": next_cursor} if next_cursor else {}),
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
}
response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response_json = response.json()
results.extend(response_json.get("results", []))
has_more = response_json.get("has_more", False)
next_cursor = response_json.get("next_cursor", None)
return results

11
dify/api/libs/orjson.py Normal file
View File

@@ -0,0 +1,11 @@
from typing import Any
import orjson
def orjson_dumps(
obj: Any,
encoding: str = "utf-8",
option: int | None = None,
) -> str:
return orjson.dumps(obj, option=option).decode(encoding)

24
dify/api/libs/passport.py Normal file
View File

@@ -0,0 +1,24 @@
import jwt
from werkzeug.exceptions import Unauthorized
from configs import dify_config
class PassportService:
def __init__(self):
self.sk = dify_config.SECRET_KEY
def issue(self, payload):
return jwt.encode(payload, self.sk, algorithm="HS256")
def verify(self, token):
try:
return jwt.decode(token, self.sk, algorithms=["HS256"])
except jwt.ExpiredSignatureError:
raise Unauthorized("Token has expired.")
except jwt.InvalidSignatureError:
raise Unauthorized("Invalid token signature.")
except jwt.DecodeError:
raise Unauthorized("Invalid token.")
except jwt.PyJWTError: # Catch-all for other JWT errors
raise Unauthorized("Invalid token.")

26
dify/api/libs/password.py Normal file
View File

@@ -0,0 +1,26 @@
import base64
import binascii
import hashlib
import re
password_pattern = r"^(?=.*[a-zA-Z])(?=.*\d).{8,}$"
def valid_password(password):
# Define a regex pattern for password rules
pattern = password_pattern
# Check if the password matches the pattern
if re.match(pattern, password) is not None:
return password
raise ValueError("Password must contain letters and numbers, and the length must be greater than 8.")
def hash_password(password_str, salt_byte):
dk = hashlib.pbkdf2_hmac("sha256", password_str.encode("utf-8"), salt_byte, 10000)
return binascii.hexlify(dk)
def compare_password(password_str, password_hashed_base64, salt_base64):
# compare password for login
return hash_password(password_str, base64.b64decode(salt_base64)) == base64.b64decode(password_hashed_base64)

94
dify/api/libs/rsa.py Normal file
View File

@@ -0,0 +1,94 @@
import hashlib
from typing import Union
from Crypto.Cipher import AES
from Crypto.PublicKey import RSA
from Crypto.Random import get_random_bytes
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import gmpy2_pkcs10aep_cipher
def generate_key_pair(tenant_id: str) -> str:
private_key = RSA.generate(2048)
public_key = private_key.publickey()
pem_private = private_key.export_key()
pem_public = public_key.export_key()
filepath = f"privkeys/{tenant_id}/private.pem"
storage.save(filepath, pem_private)
return pem_public.decode()
prefix_hybrid = b"HYBRID:"
def encrypt(text: str, public_key: Union[str, bytes]) -> bytes:
if isinstance(public_key, str):
public_key = public_key.encode()
aes_key = get_random_bytes(16)
cipher_aes = AES.new(aes_key, AES.MODE_EAX)
ciphertext, tag = cipher_aes.encrypt_and_digest(text.encode())
rsa_key = RSA.import_key(public_key)
cipher_rsa = gmpy2_pkcs10aep_cipher.new(rsa_key)
enc_aes_key: bytes = cipher_rsa.encrypt(aes_key)
encrypted_data = enc_aes_key + cipher_aes.nonce + tag + ciphertext
return prefix_hybrid + encrypted_data
def get_decrypt_decoding(tenant_id: str) -> tuple[RSA.RsaKey, object]:
filepath = f"privkeys/{tenant_id}/private.pem"
cache_key = f"tenant_privkey:{hashlib.sha3_256(filepath.encode()).hexdigest()}"
private_key = redis_client.get(cache_key)
if not private_key:
try:
private_key = storage.load(filepath)
except FileNotFoundError:
raise PrivkeyNotFoundError(f"Private key not found, tenant_id: {tenant_id}")
redis_client.setex(cache_key, 120, private_key)
rsa_key = RSA.import_key(private_key)
cipher_rsa = gmpy2_pkcs10aep_cipher.new(rsa_key)
return rsa_key, cipher_rsa
def decrypt_token_with_decoding(encrypted_text: bytes, rsa_key: RSA.RsaKey, cipher_rsa) -> str:
if encrypted_text.startswith(prefix_hybrid):
encrypted_text = encrypted_text[len(prefix_hybrid) :]
enc_aes_key = encrypted_text[: rsa_key.size_in_bytes()]
nonce = encrypted_text[rsa_key.size_in_bytes() : rsa_key.size_in_bytes() + 16]
tag = encrypted_text[rsa_key.size_in_bytes() + 16 : rsa_key.size_in_bytes() + 32]
ciphertext = encrypted_text[rsa_key.size_in_bytes() + 32 :]
aes_key = cipher_rsa.decrypt(enc_aes_key)
cipher_aes = AES.new(aes_key, AES.MODE_EAX, nonce=nonce)
decrypted_text = cipher_aes.decrypt_and_verify(ciphertext, tag)
else:
decrypted_text = cipher_rsa.decrypt(encrypted_text)
return decrypted_text.decode()
def decrypt(encrypted_text: bytes, tenant_id: str) -> str:
rsa_key, cipher_rsa = get_decrypt_decoding(tenant_id)
return decrypt_token_with_decoding(encrypted_text=encrypted_text, rsa_key=rsa_key, cipher_rsa=cipher_rsa)
class PrivkeyNotFoundError(Exception):
pass

View File

@@ -0,0 +1,108 @@
from datetime import UTC, datetime
import pytz
from croniter import croniter
def calculate_next_run_at(
cron_expression: str,
timezone: str,
base_time: datetime | None = None,
) -> datetime:
"""
Calculate the next run time for a cron expression in a specific timezone.
Args:
cron_expression: Standard 5-field cron expression or predefined expression
timezone: Timezone string (e.g., 'UTC', 'America/New_York')
base_time: Base time to calculate from (defaults to current UTC time)
Returns:
Next run time in UTC
Note:
Supports enhanced cron syntax including:
- Month abbreviations: JAN, FEB, MAR-JUN, JAN,JUN,DEC
- Day abbreviations: MON, TUE, MON-FRI, SUN,WED,FRI
- Predefined expressions: @daily, @weekly, @monthly, @yearly, @hourly
- Special characters: ? wildcard, L (last day), Sunday as 7
- Standard 5-field format only (minute hour day month dayOfWeek)
"""
# Validate cron expression format to match frontend behavior
parts = cron_expression.strip().split()
# Support both 5-field format and predefined expressions (matching frontend)
if len(parts) != 5 and not cron_expression.startswith("@"):
raise ValueError(
f"Cron expression must have exactly 5 fields or be a predefined expression "
f"(@daily, @weekly, etc.). Got {len(parts)} fields: '{cron_expression}'"
)
tz = pytz.timezone(timezone)
if base_time is None:
base_time = datetime.now(UTC)
base_time_tz = base_time.astimezone(tz)
cron = croniter(cron_expression, base_time_tz)
next_run_tz = cron.get_next(datetime)
next_run_utc = next_run_tz.astimezone(UTC)
return next_run_utc
def convert_12h_to_24h(time_str: str) -> tuple[int, int]:
"""
Parse 12-hour time format to 24-hour format for cron compatibility.
Args:
time_str: Time string in format "HH:MM AM/PM" (e.g., "12:30 PM")
Returns:
Tuple of (hour, minute) in 24-hour format
Raises:
ValueError: If time string format is invalid or values are out of range
Examples:
- "12:00 AM" -> (0, 0) # Midnight
- "12:00 PM" -> (12, 0) # Noon
- "1:30 PM" -> (13, 30)
- "11:59 PM" -> (23, 59)
"""
if not time_str or not time_str.strip():
raise ValueError("Time string cannot be empty")
parts = time_str.strip().split()
if len(parts) != 2:
raise ValueError(f"Invalid time format: '{time_str}'. Expected 'HH:MM AM/PM'")
time_part, period = parts
period = period.upper()
if period not in ["AM", "PM"]:
raise ValueError(f"Invalid period: '{period}'. Must be 'AM' or 'PM'")
time_parts = time_part.split(":")
if len(time_parts) != 2:
raise ValueError(f"Invalid time format: '{time_part}'. Expected 'HH:MM'")
try:
hour = int(time_parts[0])
minute = int(time_parts[1])
except ValueError as e:
raise ValueError(f"Invalid time values: {e}")
if hour < 1 or hour > 12:
raise ValueError(f"Invalid hour: {hour}. Must be between 1 and 12")
if minute < 0 or minute > 59:
raise ValueError(f"Invalid minute: {minute}. Must be between 0 and 59")
# Handle 12-hour to 24-hour edge cases
if period == "PM" and hour != 12:
hour += 12
elif period == "AM" and hour == 12:
hour = 0
return hour, minute

47
dify/api/libs/sendgrid.py Normal file
View File

@@ -0,0 +1,47 @@
import logging
import sendgrid
from python_http_client.exceptions import ForbiddenError, UnauthorizedError
from sendgrid.helpers.mail import Content, Email, Mail, To
logger = logging.getLogger(__name__)
class SendGridClient:
def __init__(self, sendgrid_api_key: str, _from: str):
self.sendgrid_api_key = sendgrid_api_key
self._from = _from
def send(self, mail: dict):
logger.debug("Sending email with SendGrid")
_to = ""
try:
_to = mail["to"]
if not _to:
raise ValueError("SendGridClient: Cannot send email: recipient address is missing.")
sg = sendgrid.SendGridAPIClient(api_key=self.sendgrid_api_key)
from_email = Email(self._from)
to_email = To(_to)
subject = mail["subject"]
content = Content("text/html", mail["html"])
sg_mail = Mail(from_email, to_email, subject, content)
mail_json = sg_mail.get()
response = sg.client.mail.send.post(request_body=mail_json) # type: ignore
logger.debug(response.status_code)
logger.debug(response.body)
logger.debug(response.headers)
except TimeoutError:
logger.exception("SendGridClient Timeout occurred while sending email")
raise
except (UnauthorizedError, ForbiddenError):
logger.exception(
"SendGridClient Authentication failed. "
"Verify that your credentials and the 'from' email address are correct"
)
raise
except Exception:
logger.exception("SendGridClient Unexpected error occurred while sending email to %s", _to)
raise

59
dify/api/libs/smtp.py Normal file
View File

@@ -0,0 +1,59 @@
import logging
import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
logger = logging.getLogger(__name__)
class SMTPClient:
def __init__(
self, server: str, port: int, username: str, password: str, _from: str, use_tls=False, opportunistic_tls=False
):
self.server = server
self.port = port
self._from = _from
self.username = username
self.password = password
self.use_tls = use_tls
self.opportunistic_tls = opportunistic_tls
def send(self, mail: dict):
smtp = None
try:
if self.use_tls:
if self.opportunistic_tls:
smtp = smtplib.SMTP(self.server, self.port, timeout=10)
# Send EHLO command with the HELO domain name as the server address
smtp.ehlo(self.server)
smtp.starttls()
# Resend EHLO command to identify the TLS session
smtp.ehlo(self.server)
else:
smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10)
else:
smtp = smtplib.SMTP(self.server, self.port, timeout=10)
# Only authenticate if both username and password are non-empty
if self.username and self.password and self.username.strip() and self.password.strip():
smtp.login(self.username, self.password)
msg = MIMEMultipart()
msg["Subject"] = mail["subject"]
msg["From"] = self._from
msg["To"] = mail["to"]
msg.attach(MIMEText(mail["html"], "html"))
smtp.sendmail(self._from, mail["to"], msg.as_string())
except smtplib.SMTPException:
logger.exception("SMTP error occurred")
raise
except TimeoutError:
logger.exception("Timeout occurred while sending email")
raise
except Exception:
logger.exception("Unexpected error occurred while sending email to %s", mail["to"])
raise
finally:
if smtp:
smtp.quit()

View File

@@ -0,0 +1,67 @@
"""Time duration parser utility."""
import re
from datetime import UTC, datetime, timedelta
def parse_time_duration(duration_str: str) -> timedelta | None:
"""
Parse time duration string to timedelta.
Supported formats:
- 7d: 7 days
- 4h: 4 hours
- 30m: 30 minutes
- 30s: 30 seconds
Args:
duration_str: Duration string (e.g., "7d", "4h", "30m", "30s")
Returns:
timedelta object or None if invalid format
"""
if not duration_str:
return None
# Pattern: number followed by unit (d, h, m, s)
pattern = r"^(\d+)([dhms])$"
match = re.match(pattern, duration_str.lower())
if not match:
return None
value = int(match.group(1))
unit = match.group(2)
if unit == "d":
return timedelta(days=value)
elif unit == "h":
return timedelta(hours=value)
elif unit == "m":
return timedelta(minutes=value)
elif unit == "s":
return timedelta(seconds=value)
return None
def get_time_threshold(duration_str: str | None) -> datetime | None:
"""
Get datetime threshold from duration string.
Calculates the datetime that is duration_str ago from now.
Args:
duration_str: Duration string (e.g., "7d", "4h", "30m", "30s")
Returns:
datetime object representing the threshold time, or None if no duration
"""
if not duration_str:
return None
duration = parse_time_duration(duration_str)
if duration is None:
return None
return datetime.now(UTC) - duration

231
dify/api/libs/token.py Normal file
View File

@@ -0,0 +1,231 @@
import logging
import re
from datetime import UTC, datetime, timedelta
from flask import Request
from werkzeug.exceptions import Unauthorized
from werkzeug.wrappers import Response
from configs import dify_config
from constants import (
COOKIE_NAME_ACCESS_TOKEN,
COOKIE_NAME_CSRF_TOKEN,
COOKIE_NAME_PASSPORT,
COOKIE_NAME_REFRESH_TOKEN,
COOKIE_NAME_WEBAPP_ACCESS_TOKEN,
HEADER_NAME_CSRF_TOKEN,
HEADER_NAME_PASSPORT,
)
from libs.passport import PassportService
logger = logging.getLogger(__name__)
CSRF_WHITE_LIST = [
re.compile(r"/console/api/apps/[a-f0-9-]+/workflows/draft"),
]
# server is behind a reverse proxy, so we need to check the url
def is_secure() -> bool:
return dify_config.CONSOLE_WEB_URL.startswith("https") and dify_config.CONSOLE_API_URL.startswith("https")
def _cookie_domain() -> str | None:
"""
Returns the normalized cookie domain.
Leading dots are stripped from the configured domain. Historically, a leading dot
indicated that a cookie should be sent to all subdomains, but modern browsers treat
'example.com' and '.example.com' identically. This normalization ensures consistent
behavior and avoids confusion.
"""
domain = dify_config.COOKIE_DOMAIN.strip()
domain = domain.removeprefix(".")
return domain or None
def _real_cookie_name(cookie_name: str) -> str:
if is_secure() and _cookie_domain() is None:
return "__Host-" + cookie_name
else:
return cookie_name
def _try_extract_from_header(request: Request) -> str | None:
auth_header = request.headers.get("Authorization")
if auth_header:
if " " not in auth_header:
return None
else:
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
return None
else:
return auth_token
return None
def extract_refresh_token(request: Request) -> str | None:
return request.cookies.get(_real_cookie_name(COOKIE_NAME_REFRESH_TOKEN))
def extract_csrf_token(request: Request) -> str | None:
return request.headers.get(HEADER_NAME_CSRF_TOKEN)
def extract_csrf_token_from_cookie(request: Request) -> str | None:
return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN))
def extract_access_token(request: Request) -> str | None:
def _try_extract_from_cookie(request: Request) -> str | None:
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
return _try_extract_from_cookie(request) or _try_extract_from_header(request)
def extract_webapp_access_token(request: Request) -> str | None:
return request.cookies.get(_real_cookie_name(COOKIE_NAME_WEBAPP_ACCESS_TOKEN)) or _try_extract_from_header(request)
def extract_webapp_passport(app_code: str, request: Request) -> str | None:
def _try_extract_passport_token_from_cookie(request: Request) -> str | None:
return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code))
def _try_extract_passport_token_from_header(request: Request) -> str | None:
return request.headers.get(HEADER_NAME_PASSPORT)
ret = _try_extract_passport_token_from_cookie(request) or _try_extract_passport_token_from_header(request)
return ret
def set_access_token_to_cookie(request: Request, response: Response, token: str, samesite: str = "Lax"):
response.set_cookie(
_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN),
value=token,
httponly=True,
domain=_cookie_domain(),
secure=is_secure(),
samesite=samesite,
max_age=int(dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 60),
path="/",
)
def set_refresh_token_to_cookie(request: Request, response: Response, token: str):
response.set_cookie(
_real_cookie_name(COOKIE_NAME_REFRESH_TOKEN),
value=token,
httponly=True,
domain=_cookie_domain(),
secure=is_secure(),
samesite="Lax",
max_age=int(60 * 60 * 24 * dify_config.REFRESH_TOKEN_EXPIRE_DAYS),
path="/",
)
def set_csrf_token_to_cookie(request: Request, response: Response, token: str):
response.set_cookie(
_real_cookie_name(COOKIE_NAME_CSRF_TOKEN),
value=token,
httponly=False,
domain=_cookie_domain(),
secure=is_secure(),
samesite="Lax",
max_age=int(60 * dify_config.ACCESS_TOKEN_EXPIRE_MINUTES),
path="/",
)
def _clear_cookie(
response: Response,
cookie_name: str,
samesite: str = "Lax",
http_only: bool = True,
):
response.set_cookie(
_real_cookie_name(cookie_name),
"",
expires=0,
path="/",
domain=_cookie_domain(),
secure=is_secure(),
httponly=http_only,
samesite=samesite,
)
def clear_access_token_from_cookie(response: Response, samesite: str = "Lax"):
_clear_cookie(response, COOKIE_NAME_ACCESS_TOKEN, samesite)
def clear_webapp_access_token_from_cookie(response: Response, samesite: str = "Lax"):
_clear_cookie(response, COOKIE_NAME_WEBAPP_ACCESS_TOKEN, samesite)
def clear_refresh_token_from_cookie(response: Response):
_clear_cookie(response, COOKIE_NAME_REFRESH_TOKEN)
def clear_csrf_token_from_cookie(response: Response):
_clear_cookie(response, COOKIE_NAME_CSRF_TOKEN, http_only=False)
def build_force_logout_cookie_headers() -> list[str]:
"""
Generate Set-Cookie header values that clear all auth-related cookies.
This mirrors the behavior of the standard cookie clearing helpers while
allowing callers that do not have a Response instance to reuse the logic.
"""
response = Response()
clear_access_token_from_cookie(response)
clear_csrf_token_from_cookie(response)
clear_refresh_token_from_cookie(response)
return response.headers.getlist("Set-Cookie")
def check_csrf_token(request: Request, user_id: str):
# some apis are sent by beacon, so we need to bypass csrf token check
# since these APIs are post, they are already protected by SameSite: Lax, so csrf is not required.
def _unauthorized():
raise Unauthorized("CSRF token is missing or invalid.")
for pattern in CSRF_WHITE_LIST:
if pattern.match(request.path):
return
csrf_token = extract_csrf_token(request)
csrf_token_from_cookie = extract_csrf_token_from_cookie(request)
if csrf_token != csrf_token_from_cookie:
_unauthorized()
if not csrf_token:
_unauthorized()
verified = {}
try:
verified = PassportService().verify(csrf_token)
except:
_unauthorized()
if verified.get("sub") != user_id:
_unauthorized()
exp: int | None = verified.get("exp")
if not exp:
_unauthorized()
else:
time_now = int(datetime.now().timestamp())
if exp < time_now:
_unauthorized()
def generate_csrf_token(user_id: str) -> str:
exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
payload = {
"exp": int(exp_dt.timestamp()),
"sub": user_id,
}
return PassportService().issue(payload)

9
dify/api/libs/typing.py Normal file
View File

@@ -0,0 +1,9 @@
from typing import TypeGuard
def is_str_dict(v: object) -> TypeGuard[dict[str, object]]:
return isinstance(v, dict)
def is_str(v: object) -> TypeGuard[str]:
return isinstance(v, str)

164
dify/api/libs/uuid_utils.py Normal file
View File

@@ -0,0 +1,164 @@
import secrets
import struct
import time
import uuid
# Reference for UUIDv7 specification:
# RFC 9562, Section 5.7 - https://www.rfc-editor.org/rfc/rfc9562.html#section-5.7
# Define the format for packing the timestamp as an unsigned 64-bit integer (big-endian).
#
# For details on the `struct.pack` format, refer to:
# https://docs.python.org/3/library/struct.html#byte-order-size-and-alignment
_PACK_TIMESTAMP = ">Q"
# Define the format for packing the 12-bit random data A (as specified in RFC 9562 Section 5.7)
# into an unsigned 16-bit integer (big-endian).
_PACK_RAND_A = ">H"
def _create_uuidv7_bytes(timestamp_ms: int, random_bytes: bytes) -> bytes:
"""Create UUIDv7 byte structure with given timestamp and random bytes.
This is a private helper function that handles the common logic for creating
UUIDv7 byte structure according to RFC 9562 specification.
UUIDv7 Structure:
- 48 bits: timestamp (milliseconds since Unix epoch)
- 12 bits: random data A (with version bits)
- 62 bits: random data B (with variant bits)
The function performs the following operations:
1. Creates a 128-bit (16-byte) UUID structure
2. Packs the timestamp into the first 48 bits (6 bytes)
3. Sets the version bits to 7 (0111) in the correct position
4. Sets the variant bits to 10 (binary) in the correct position
5. Fills the remaining bits with the provided random bytes
Args:
timestamp_ms: The timestamp in milliseconds since Unix epoch (48 bits).
random_bytes: Random bytes to use for the random portions (must be 10 bytes).
First 2 bytes are used for random data A (12 bits after version).
Last 8 bytes are used for random data B (62 bits after variant).
Returns:
A 16-byte bytes object representing the complete UUIDv7 structure.
Note:
This function assumes the random_bytes parameter is exactly 10 bytes.
The caller is responsible for providing appropriate random data.
"""
# Create the 128-bit UUID structure
uuid_bytes = bytearray(16)
# Pack timestamp (48 bits) into first 6 bytes
uuid_bytes[0:6] = struct.pack(_PACK_TIMESTAMP, timestamp_ms)[2:8] # Take last 6 bytes of 8-byte big-endian
# Next 16 bits: random data A (12 bits) + version (4 bits)
# Take first 2 random bytes and set version to 7
rand_a = struct.unpack(_PACK_RAND_A, random_bytes[0:2])[0]
# Clear the highest 4 bits to make room for the version field
# by performing a bitwise AND with 0x0FFF (binary: 0b0000_1111_1111_1111).
rand_a = rand_a & 0x0FFF
# Set the version field to 7 (binary: 0111) by performing a bitwise OR with 0x7000 (binary: 0b0111_0000_0000_0000).
rand_a = rand_a | 0x7000
uuid_bytes[6:8] = struct.pack(_PACK_RAND_A, rand_a)
# Last 64 bits: random data B (62 bits) + variant (2 bits)
# Use remaining 8 random bytes and set variant to 10 (binary)
uuid_bytes[8:16] = random_bytes[2:10]
# Set variant bits (first 2 bits of byte 8 should be '10')
uuid_bytes[8] = (uuid_bytes[8] & 0x3F) | 0x80 # Set variant to 10xxxxxx
return bytes(uuid_bytes)
def uuidv7(timestamp_ms: int | None = None) -> uuid.UUID:
"""Generate a UUID version 7 according to RFC 9562 specification.
UUIDv7 features a time-ordered value field derived from the widely
implemented and well known Unix Epoch timestamp source, the number of
milliseconds since midnight 1 Jan 1970 UTC, leap seconds excluded.
Structure:
- 48 bits: timestamp (milliseconds since Unix epoch)
- 12 bits: random data A (with version bits)
- 62 bits: random data B (with variant bits)
Args:
timestamp_ms: The timestamp used when generating UUID, use the current time if unspecified.
Should be an integer representing milliseconds since Unix epoch.
Returns:
A UUID object representing a UUIDv7.
Example:
>>> import time
>>> # Generate UUIDv7 with current time
>>> uuid_current = uuidv7()
>>> # Generate UUIDv7 with specific timestamp
>>> uuid_specific = uuidv7(int(time.time() * 1000))
"""
if timestamp_ms is None:
timestamp_ms = int(time.time() * 1000)
# Generate 10 random bytes for the random portions
random_bytes = secrets.token_bytes(10)
# Create UUIDv7 bytes using the helper function
uuid_bytes = _create_uuidv7_bytes(timestamp_ms, random_bytes)
return uuid.UUID(bytes=uuid_bytes)
def uuidv7_timestamp(id_: uuid.UUID) -> int:
"""Extract the timestamp from a UUIDv7.
UUIDv7 contains a 48-bit timestamp field representing milliseconds since
the Unix epoch (1970-01-01 00:00:00 UTC). This function extracts and
returns that timestamp as an integer representing milliseconds since the epoch.
Args:
id_: A UUID object that should be a UUIDv7 (version 7).
Returns:
The timestamp as an integer representing milliseconds since Unix epoch.
Raises:
ValueError: If the provided UUID is not version 7.
Example:
>>> uuid_v7 = uuidv7()
>>> timestamp = uuidv7_timestamp(uuid_v7)
>>> print(f"UUID was created at: {timestamp} ms")
"""
# Verify this is a UUIDv7
if id_.version != 7:
raise ValueError(f"Expected UUIDv7 (version 7), got version {id_.version}")
# Extract the UUID bytes
uuid_bytes = id_.bytes
# Extract the first 48 bits (6 bytes) as the timestamp in milliseconds
# Pad with 2 zero bytes at the beginning to make it 8 bytes for unpacking as Q (unsigned long long)
timestamp_bytes = b"\x00\x00" + uuid_bytes[0:6]
ts_in_ms = struct.unpack(_PACK_TIMESTAMP, timestamp_bytes)[0]
# Return timestamp directly in milliseconds as integer
assert isinstance(ts_in_ms, int)
return ts_in_ms
def uuidv7_boundary(timestamp_ms: int) -> uuid.UUID:
"""Generate a non-random uuidv7 with the given timestamp (first 48 bits) and
all random bits to 0. As the smallest possible uuidv7 for that timestamp,
it may be used as a boundary for partitions.
"""
# Use zero bytes for all random portions
zero_random_bytes = b"\x00" * 10
# Create UUIDv7 bytes using the helper function
uuid_bytes = _create_uuidv7_bytes(timestamp_ms, zero_random_bytes)
return uuid.UUID(bytes=uuid_bytes)

View File

@@ -0,0 +1,5 @@
def validate_description_length(description: str | None) -> str | None:
"""Validate description length."""
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description