2025-12-01 17:21:38 +08:00
parent 32fee2b8ab
commit fab8c13cb3
7511 changed files with 996300 additions and 0 deletions

@@ -0,0 +1,65 @@
import math
import time
import click
import app
from extensions.ext_database import db
from models.account import TenantPluginAutoUpgradeStrategy
from tasks import process_tenant_plugin_autoupgrade_check_task as check_task
AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL = 15 * 60 # 15 minutes
MAX_CONCURRENT_CHECK_TASKS = 20
@app.celery.task(queue="plugin")
def check_upgradable_plugin_task():
click.echo(click.style("Start check upgradable plugin.", fg="green"))
start_at = time.perf_counter()
now_seconds_of_day = time.time() % 86400 - 30 # we assume the tz is UTC
click.echo(click.style(f"Now seconds of day: {now_seconds_of_day}", fg="green"))
strategies = (
db.session.query(TenantPluginAutoUpgradeStrategy)
.where(
TenantPluginAutoUpgradeStrategy.upgrade_time_of_day >= now_seconds_of_day,
TenantPluginAutoUpgradeStrategy.upgrade_time_of_day
< now_seconds_of_day + AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL,
TenantPluginAutoUpgradeStrategy.strategy_setting
!= TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED,
)
.all()
)
total_strategies = len(strategies)
click.echo(click.style(f"Total strategies: {total_strategies}", fg="green"))
batch_chunk_count = math.ceil(
total_strategies / MAX_CONCURRENT_CHECK_TASKS
) # make sure all strategies are checked in this interval
batch_interval_time = (AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL / batch_chunk_count) if batch_chunk_count > 0 else 0
for i in range(0, total_strategies, MAX_CONCURRENT_CHECK_TASKS):
batch_strategies = strategies[i : i + MAX_CONCURRENT_CHECK_TASKS]
for strategy in batch_strategies:
check_task.process_tenant_plugin_autoupgrade_check_task.delay(
strategy.tenant_id,
strategy.strategy_setting,
strategy.upgrade_time_of_day,
strategy.upgrade_mode,
strategy.exclude_plugins,
strategy.include_plugins,
)
# Only sleep if batch_interval_time > 0.0001 AND current batch is not the last one
if batch_interval_time > 0.0001 and i + MAX_CONCURRENT_CHECK_TASKS < total_strategies:
time.sleep(batch_interval_time)
end_at = time.perf_counter()
click.echo(
click.style(
f"Checked upgradable plugin success latency: {end_at - start_at}",
fg="green",
)
)
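
The loop above spreads the strategy checks over the 15-minute window: strategies are enqueued in slices of MAX_CONCURRENT_CHECK_TASKS, with a sleep between slices sized so every slice still lands inside the interval. A standalone sketch of the same pacing logic (illustrative names, not part of this commit):

import math
import time

WINDOW_SECONDS = 15 * 60
BATCH_SIZE = 20

def paced_dispatch(items, dispatch, window=WINDOW_SECONDS, batch_size=BATCH_SIZE):
    """Dispatch items in batches, pausing between batches so they spread across the window."""
    batches = math.ceil(len(items) / batch_size)
    interval = window / batches if batches > 0 else 0
    for i in range(0, len(items), batch_size):
        for item in items[i : i + batch_size]:
            dispatch(item)
        # No sleep after the final batch, mirroring the task above.
        if interval > 0.0001 and i + batch_size < len(items):
            time.sleep(interval)

# Example: 45 items with a batch size of 20 -> 3 batches roughly 300 seconds apart.
# paced_dispatch(list(range(45)), print)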

@@ -0,0 +1,42 @@
import datetime
import time
import click
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
import app
from configs import dify_config
from extensions.ext_database import db
from models.dataset import Embedding
@app.celery.task(queue="dataset")
def clean_embedding_cache_task():
click.echo(click.style("Start clean embedding cache.", fg="green"))
clean_days = int(dify_config.PLAN_SANDBOX_CLEAN_DAY_SETTING)
start_at = time.perf_counter()
thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
while True:
try:
embedding_ids = (
db.session.query(Embedding.id)
.where(Embedding.created_at < thirty_days_ago)
.order_by(Embedding.created_at.desc())
.limit(100)
.all()
)
embedding_ids = [embedding_id[0] for embedding_id in embedding_ids]
except SQLAlchemyError:
raise
if embedding_ids:
for embedding_id in embedding_ids:
db.session.execute(
text("DELETE FROM embeddings WHERE id = :embedding_id"), {"embedding_id": embedding_id}
)
db.session.commit()
else:
break
end_at = time.perf_counter()
click.echo(click.style(f"Cleaned embedding cache from db success latency: {end_at - start_at}", fg="green"))

@@ -0,0 +1,90 @@
import datetime
import logging
import time
import click
from sqlalchemy.exc import SQLAlchemyError
import app
from configs import dify_config
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import (
App,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
@app.celery.task(queue="dataset")
def clean_messages():
click.echo(click.style("Start clean messages.", fg="green"))
start_at = time.perf_counter()
plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta(
days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING
)
while True:
try:
# Main query with join and filter
messages = (
db.session.query(Message)
.where(Message.created_at < plan_sandbox_clean_message_day)
.order_by(Message.created_at.desc())
.limit(100)
.all()
)
except SQLAlchemyError:
raise
if not messages:
break
for message in messages:
app = db.session.query(App).filter_by(id=message.app_id).first()
if not app:
logger.warning(
"Expected App record to exist, but none was found, app_id=%s, message_id=%s",
message.app_id,
message.id,
)
continue
features_cache_key = f"features:{app.tenant_id}"
plan_cache = redis_client.get(features_cache_key)
if plan_cache is None:
features = FeatureService.get_features(app.tenant_id)
redis_client.setex(features_cache_key, 600, features.billing.subscription.plan)
plan = features.billing.subscription.plan
else:
plan = plan_cache.decode()
if plan == CloudPlan.SANDBOX:
# clean related message
db.session.query(MessageFeedback).where(MessageFeedback.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageChain).where(MessageChain.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageFile).where(MessageFile.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(SavedMessage).where(SavedMessage.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(Message).where(Message.id == message.id).delete()
db.session.commit()
end_at = time.perf_counter()
click.echo(click.style(f"Cleaned messages from db success latency: {end_at - start_at}", fg="green"))

@@ -0,0 +1,156 @@
import datetime
import time
from typing import TypedDict
import click
from sqlalchemy import func, select
from sqlalchemy.exc import SQLAlchemyError
import app
from configs import dify_config
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DatasetAutoDisableLog, DatasetQuery, Document
from services.feature_service import FeatureService

class CleanupConfig(TypedDict):
clean_day: datetime.datetime
plan_filter: str | None
add_logs: bool
@app.celery.task(queue="dataset")
def clean_unused_datasets_task():
click.echo(click.style("Start clean unused datasets indexes.", fg="green"))
start_at = time.perf_counter()
# Define cleanup configurations
cleanup_configs: list[CleanupConfig] = [
{
"clean_day": datetime.datetime.now() - datetime.timedelta(days=dify_config.PLAN_SANDBOX_CLEAN_DAY_SETTING),
"plan_filter": None,
"add_logs": True,
},
{
"clean_day": datetime.datetime.now() - datetime.timedelta(days=dify_config.PLAN_PRO_CLEAN_DAY_SETTING),
"plan_filter": CloudPlan.SANDBOX,
"add_logs": False,
},
]
for config in cleanup_configs:
clean_day = config["clean_day"]
plan_filter = config["plan_filter"]
add_logs = config["add_logs"]
page = 1
while True:
try:
# Subquery for counting new documents
document_subquery_new = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
Document.updated_at > clean_day,
)
.group_by(Document.dataset_id)
.subquery()
)
# Subquery for counting old documents
document_subquery_old = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
Document.updated_at < clean_day,
)
.group_by(Document.dataset_id)
.subquery()
)
# Main query with join and filter
stmt = (
select(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.where(
Dataset.created_at < clean_day,
func.coalesce(document_subquery_new.c.document_count, 0) == 0,
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
)
.order_by(Dataset.created_at.desc())
)
datasets = db.paginate(stmt, page=page, per_page=50, error_out=False)
except SQLAlchemyError:
raise
if datasets is None or datasets.items is None or len(datasets.items) == 0:
break
for dataset in datasets:
dataset_query = db.session.scalars(
select(DatasetQuery).where(
DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id
)
).all()
if not dataset_query or len(dataset_query) == 0:
try:
should_clean = True
# Check plan filter if specified
if plan_filter:
features_cache_key = f"features:{dataset.tenant_id}"
plan_cache = redis_client.get(features_cache_key)
if plan_cache is None:
features = FeatureService.get_features(dataset.tenant_id)
redis_client.setex(features_cache_key, 600, features.billing.subscription.plan)
plan = features.billing.subscription.plan
else:
plan = plan_cache.decode()
should_clean = plan == plan_filter
if should_clean:
# Add auto disable log if required
if add_logs:
documents = db.session.scalars(
select(Document).where(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
).all()
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
)
db.session.add(dataset_auto_disable_log)
# Remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None)
# Update document
db.session.query(Document).filter_by(dataset_id=dataset.id).update(
{Document.enabled: False}
)
db.session.commit()
click.echo(click.style(f"Cleaned unused dataset {dataset.id} from db success!", fg="green"))
except Exception as e:
click.echo(click.style(f"clean dataset index error: {e.__class__.__name__} {str(e)}", fg="red"))
page += 1
end_at = time.perf_counter()
click.echo(click.style(f"Cleaned unused dataset from db success latency: {end_at - start_at}", fg="green"))

@@ -0,0 +1,148 @@
import datetime
import logging
import time
from collections.abc import Sequence
import click
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
import app
from configs import dify_config
from extensions.ext_database import db
from models.model import (
AppAnnotationHitHistory,
Conversation,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.workflow import ConversationVariable, WorkflowAppLog, WorkflowNodeExecutionModel, WorkflowRun
logger = logging.getLogger(__name__)
MAX_RETRIES = 3
BATCH_SIZE = dify_config.WORKFLOW_LOG_CLEANUP_BATCH_SIZE
@app.celery.task(queue="dataset")
def clean_workflow_runlogs_precise():
"""Clean expired workflow run logs with retry mechanism and complete message cascade"""
click.echo(click.style("Start clean workflow run logs (precise mode with complete cascade).", fg="green"))
start_at = time.perf_counter()
retention_days = dify_config.WORKFLOW_LOG_RETENTION_DAYS
cutoff_date = datetime.datetime.now() - datetime.timedelta(days=retention_days)
session_factory = sessionmaker(db.engine, expire_on_commit=False)
try:
with session_factory.begin() as session:
total_workflow_runs = session.query(WorkflowRun).where(WorkflowRun.created_at < cutoff_date).count()
if total_workflow_runs == 0:
logger.info("No expired workflow run logs found")
return
logger.info("Found %s expired workflow run logs to clean", total_workflow_runs)
total_deleted = 0
failed_batches = 0
batch_count = 0
while True:
with session_factory.begin() as session:
workflow_run_ids = session.scalars(
select(WorkflowRun.id)
.where(WorkflowRun.created_at < cutoff_date)
.order_by(WorkflowRun.created_at, WorkflowRun.id)
.limit(BATCH_SIZE)
).all()
if not workflow_run_ids:
break
batch_count += 1
success = _delete_batch(session, workflow_run_ids, failed_batches)
if success:
total_deleted += len(workflow_run_ids)
failed_batches = 0
else:
failed_batches += 1
if failed_batches >= MAX_RETRIES:
logger.error("Failed to delete batch after %s retries, aborting cleanup for today", MAX_RETRIES)
break
else:
# Calculate incremental delay times: 5, 10, 15 minutes
retry_delay_minutes = failed_batches * 5
logger.warning("Batch deletion failed, retrying in %s minutes...", retry_delay_minutes)
time.sleep(retry_delay_minutes * 60)
continue
logger.info("Cleanup completed: %s expired workflow run logs deleted", total_deleted)
except Exception:
logger.exception("Unexpected error in workflow log cleanup")
raise
end_at = time.perf_counter()
execution_time = end_at - start_at
click.echo(click.style(f"Cleaned workflow run logs from db success latency: {execution_time:.2f}s", fg="green"))

def _delete_batch(session: Session, workflow_run_ids: Sequence[str], attempt_count: int) -> bool:
"""Delete a single batch of workflow runs and all related data within a nested transaction."""
try:
with session.begin_nested():
message_data = (
session.query(Message.id, Message.conversation_id)
.where(Message.workflow_run_id.in_(workflow_run_ids))
.all()
)
message_id_list = [msg.id for msg in message_data]
conversation_id_list = list({msg.conversation_id for msg in message_data if msg.conversation_id})
if message_id_list:
message_related_models = [
AppAnnotationHitHistory,
MessageAgentThought,
MessageChain,
MessageFile,
MessageAnnotation,
MessageFeedback,
]
for model in message_related_models:
session.query(model).where(model.message_id.in_(message_id_list)).delete(synchronize_session=False) # type: ignore
# error: "DeclarativeAttributeIntercept" has no attribute "message_id". But this type is only in lib
# and these 6 types all have the message_id field.
session.query(Message).where(Message.workflow_run_id.in_(workflow_run_ids)).delete(
synchronize_session=False
)
session.query(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(workflow_run_ids)).delete(
synchronize_session=False
)
session.query(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.workflow_run_id.in_(workflow_run_ids)
).delete(synchronize_session=False)
if conversation_id_list:
session.query(ConversationVariable).where(
ConversationVariable.conversation_id.in_(conversation_id_list)
).delete(synchronize_session=False)
session.query(Conversation).where(Conversation.id.in_(conversation_id_list)).delete(
synchronize_session=False
)
session.query(WorkflowRun).where(WorkflowRun.id.in_(workflow_run_ids)).delete(synchronize_session=False)
return True
except Exception:
logger.exception("Batch deletion failed (attempt %s)", attempt_count + 1)
return False
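
The retry pacing in the loop above, isolated: each failed batch backs off failed_batches * 5 minutes, and the task aborts once MAX_RETRIES consecutive failures accumulate, so as written it waits 5 then 10 minutes before giving up for the day. A minimal sketch:

MAX_RETRIES = 3

def retry_delay_seconds(failed_batches: int) -> int | None:
    """Return the backoff before the next attempt, or None once the retry budget is spent."""
    if failed_batches >= MAX_RETRIES:
        return None  # abort cleanup for today
    return failed_batches * 5 * 60

assert [retry_delay_seconds(n) for n in (1, 2, 3)] == [300, 600, None]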

@@ -0,0 +1,61 @@
import time
import click
import app
from configs import dify_config
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from extensions.ext_database import db
from models.dataset import TidbAuthBinding
@app.celery.task(queue="dataset")
def create_tidb_serverless_task():
click.echo(click.style("Start create tidb serverless task.", fg="green"))
if not dify_config.CREATE_TIDB_SERVICE_JOB_ENABLED:
return
tidb_serverless_number = dify_config.TIDB_SERVERLESS_NUMBER
start_at = time.perf_counter()
while True:
try:
# check the number of idle tidb serverless
idle_tidb_serverless_number = (
db.session.query(TidbAuthBinding).where(TidbAuthBinding.active == False).count()
)
if idle_tidb_serverless_number >= tidb_serverless_number:
break
# create tidb serverless
iterations_per_thread = 20
create_clusters(iterations_per_thread)
except Exception as e:
click.echo(click.style(f"Error: {e}", fg="red"))
break
end_at = time.perf_counter()
click.echo(click.style(f"Create tidb serverless task success latency: {end_at - start_at}", fg="green"))

def create_clusters(batch_size):
try:
# TODO: maybe we can set the default value for the following parameters in the config file
new_clusters = TidbService.batch_create_tidb_serverless_cluster(
batch_size=batch_size,
project_id=dify_config.TIDB_PROJECT_ID or "",
api_url=dify_config.TIDB_API_URL or "",
iam_url=dify_config.TIDB_IAM_API_URL or "",
public_key=dify_config.TIDB_PUBLIC_KEY or "",
private_key=dify_config.TIDB_PRIVATE_KEY or "",
region=dify_config.TIDB_REGION or "",
)
for new_cluster in new_clusters:
tidb_auth_binding = TidbAuthBinding(
cluster_id=new_cluster["cluster_id"],
cluster_name=new_cluster["cluster_name"],
account=new_cluster["account"],
password=new_cluster["password"],
)
db.session.add(tidb_auth_binding)
db.session.commit()
except Exception as e:
click.echo(click.style(f"Error: {e}", fg="red"))

@@ -0,0 +1,98 @@
import logging
import time
from collections import defaultdict
import click
from sqlalchemy import select
import app
from configs import dify_config
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
from models import Account, Tenant, TenantAccountJoin
from models.dataset import Dataset, DatasetAutoDisableLog
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
@app.celery.task(queue="dataset")
def mail_clean_document_notify_task():
"""
Async Send document clean notify mail
Usage: mail_clean_document_notify_task.delay()
"""
if not mail.is_inited():
return
logger.info(click.style("Start send document clean notify mail", fg="green"))
start_at = time.perf_counter()
# send document clean notify mail
try:
dataset_auto_disable_logs = db.session.scalars(
select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False)
).all()
# group by tenant_id
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
for dataset_auto_disable_log in dataset_auto_disable_logs:
if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map:
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = []
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)
url = f"{dify_config.CONSOLE_WEB_URL}/datasets"
for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items():
features = FeatureService.get_features(tenant_id)
plan = features.billing.subscription.plan
if plan != CloudPlan.SANDBOX:
knowledge_details = []
# check tenant
tenant = db.session.query(Tenant).where(Tenant.id == tenant_id).first()
if not tenant:
continue
# check current owner
current_owner_join = (
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first()
)
if not current_owner_join:
continue
account = db.session.query(Account).where(Account.id == current_owner_join.account_id).first()
if not account:
continue
dataset_auto_dataset_map = {} # type: ignore
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map:
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = []
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
dataset_auto_disable_log.document_id
)
for dataset_id, document_ids in dataset_auto_dataset_map.items():
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset:
document_count = len(document_ids)
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
if knowledge_details:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.DOCUMENT_CLEAN_NOTIFY,
language_code="en-US",
to=account.email,
template_context={
"userName": account.email,
"knowledge_details": knowledge_details,
"url": url,
},
)
# update notified to True
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
dataset_auto_disable_log.notified = True
db.session.commit()
end_at = time.perf_counter()
logger.info(click.style(f"Send document clean notify mail succeeded: latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("Send document clean notify mail failed")

@@ -0,0 +1,73 @@
import logging
from datetime import datetime
import click
from kombu.utils.url import parse_url # type: ignore
from redis import Redis
import app
from configs import dify_config
from extensions.ext_database import db
from libs.email_i18n import EmailType, get_email_i18n_service
redis_config = parse_url(dify_config.CELERY_BROKER_URL)
celery_redis = Redis(
host=redis_config.get("hostname") or "localhost",
port=redis_config.get("port") or 6379,
password=redis_config.get("password") or None,
db=int(redis_config.get("virtual_host")) if redis_config.get("virtual_host") else 1,
)
logger = logging.getLogger(__name__)
@app.celery.task(queue="monitor")
def queue_monitor_task():
queue_name = "dataset"
threshold = dify_config.QUEUE_MONITOR_THRESHOLD
if threshold is None:
logger.warning(click.style("QUEUE_MONITOR_THRESHOLD is not configured, skipping monitoring", fg="yellow"))
return
try:
queue_length = celery_redis.llen(f"{queue_name}")
logger.info(click.style(f"Start monitor {queue_name}", fg="green"))
if queue_length is None:
logger.error(
click.style(f"Failed to get queue length for {queue_name} - Redis may be unavailable", fg="red")
)
return
logger.info(click.style(f"Queue length: {queue_length}", fg="green"))
if queue_length >= threshold:
warning_msg = f"Queue {queue_name} task count exceeded the limit.: {queue_length}/{threshold}"
logging.warning(click.style(warning_msg, fg="red"))
alert_emails = dify_config.QUEUE_MONITOR_ALERT_EMAILS
if alert_emails:
to_list = alert_emails.split(",")
email_service = get_email_i18n_service()
for to in to_list:
try:
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
email_service.send_email(
email_type=EmailType.QUEUE_MONITOR_ALERT,
language_code="en-US",
to=to,
template_context={
"queue_name": queue_name,
"queue_length": queue_length,
"threshold": threshold,
"alert_time": current_time,
},
)
except Exception:
logger.exception(click.style("Exception occurred during sending email", fg="red"))
except Exception:
logger.exception(click.style("Exception occurred during queue monitoring", fg="red"))
finally:
if db.session.is_active:
db.session.close()
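
All of the tasks in this commit are periodic; the schedule itself lives in Dify's Celery configuration, which is not part of this diff. As a rough illustration only (the task path and cadence below are assumptions, not taken from the repository), a Celery beat entry for the queue monitor could look like:

from celery.schedules import crontab

# Hypothetical beat entry; the real task path and cadence live in Dify's Celery config.
beat_schedule = {
    "queue-monitor": {
        "task": "schedule.queue_monitor_task.queue_monitor_task",  # assumed module path
        "schedule": crontab(minute="*/30"),  # assumed cadence: every 30 minutes
    },
}
# celery_app.conf.beat_schedule.update(beat_schedule)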

@@ -0,0 +1,104 @@
import logging
import math
import time
from collections.abc import Iterable, Sequence
from sqlalchemy import ColumnElement, and_, func, or_, select
from sqlalchemy.engine.row import Row
from sqlalchemy.orm import Session
import app
from configs import dify_config
from core.trigger.utils.locks import build_trigger_refresh_lock_keys
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.trigger import TriggerSubscription
from tasks.trigger_subscription_refresh_tasks import trigger_subscription_refresh
logger = logging.getLogger(__name__)

def _now_ts() -> int:
return int(time.time())

def _build_due_filter(now_ts: int):
"""Build SQLAlchemy filter for due credential or subscription refresh."""
credential_due: ColumnElement[bool] = and_(
TriggerSubscription.credential_expires_at != -1,
TriggerSubscription.credential_expires_at
<= now_ts + int(dify_config.TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS),
)
subscription_due: ColumnElement[bool] = and_(
TriggerSubscription.expires_at != -1,
TriggerSubscription.expires_at <= now_ts + int(dify_config.TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS),
)
return or_(credential_due, subscription_due)

def _acquire_locks(keys: Iterable[str], ttl_seconds: int) -> list[bool]:
"""Attempt to acquire locks in a single pipelined round-trip.
Returns a list of booleans indicating which locks were acquired.
"""
pipe = redis_client.pipeline(transaction=False)
for key in keys:
pipe.set(key, b"1", ex=ttl_seconds, nx=True)
results = pipe.execute()
return [bool(r) for r in results]
@app.celery.task(queue="trigger_refresh_publisher")
def trigger_provider_refresh() -> None:
"""
Scan due trigger subscriptions and enqueue refresh tasks with in-flight locks.
"""
now: int = _now_ts()
batch_size: int = int(dify_config.TRIGGER_PROVIDER_REFRESH_BATCH_SIZE)
lock_ttl: int = max(300, int(dify_config.TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS))
with Session(db.engine, expire_on_commit=False) as session:
filter: ColumnElement[bool] = _build_due_filter(now_ts=now)
total_due: int = int(session.scalar(statement=select(func.count()).where(filter)) or 0)
logger.info("Trigger refresh scan start: due=%d", total_due)
if total_due == 0:
return
pages: int = math.ceil(total_due / batch_size)
for page in range(pages):
offset: int = page * batch_size
subscription_rows: Sequence[Row[tuple[str, str]]] = session.execute(
select(TriggerSubscription.tenant_id, TriggerSubscription.id)
.where(filter)
.order_by(TriggerSubscription.updated_at.asc())
.offset(offset)
.limit(batch_size)
).all()
if not subscription_rows:
logger.debug("Trigger refresh page %d/%d empty", page + 1, pages)
continue
subscriptions: list[tuple[str, str]] = [
(str(tenant_id), str(subscription_id)) for tenant_id, subscription_id in subscription_rows
]
lock_keys: list[str] = build_trigger_refresh_lock_keys(subscriptions)
acquired: list[bool] = _acquire_locks(keys=lock_keys, ttl_seconds=lock_ttl)
enqueued: int = 0
for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired):
if not is_locked:
continue
trigger_subscription_refresh.delay(tenant_id=tenant_id, subscription_id=subscription_id)
enqueued += 1
logger.info(
"Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d",
page + 1,
pages,
len(subscriptions),
sum(1 for x in acquired if x),
enqueued,
)
logger.info("Trigger refresh scan done: due=%d", total_due)

@@ -0,0 +1,50 @@
import time
from collections.abc import Sequence
import click
from sqlalchemy import select
import app
from configs import dify_config
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from extensions.ext_database import db
from models.dataset import TidbAuthBinding
@app.celery.task(queue="dataset")
def update_tidb_serverless_status_task():
click.echo(click.style("Update tidb serverless status task.", fg="green"))
start_at = time.perf_counter()
try:
# check the number of idle tidb serverless
tidb_serverless_list = db.session.scalars(
select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
).all()
if len(tidb_serverless_list) == 0:
return
# update tidb serverless status
update_clusters(tidb_serverless_list)
except Exception as e:
click.echo(click.style(f"Error: {e}", fg="red"))
end_at = time.perf_counter()
click.echo(click.style(f"Update tidb serverless status task success latency: {end_at - start_at}", fg="green"))

def update_clusters(tidb_serverless_list: Sequence[TidbAuthBinding]):
try:
# batch 20
for i in range(0, len(tidb_serverless_list), 20):
items = tidb_serverless_list[i : i + 20]
# TODO: maybe we can set the default value for the following parameters in the config file
TidbService.batch_update_tidb_serverless_cluster_status(
tidb_serverless_list=items,
project_id=dify_config.TIDB_PROJECT_ID or "",
api_url=dify_config.TIDB_API_URL or "",
iam_url=dify_config.TIDB_IAM_API_URL or "",
public_key=dify_config.TIDB_PUBLIC_KEY or "",
private_key=dify_config.TIDB_PRIVATE_KEY or "",
)
except Exception as e:
click.echo(click.style(f"Error: {e}", fg="red"))

@@ -0,0 +1,116 @@
import logging
from celery import group, shared_task
from sqlalchemy import and_, select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.schedule_utils import calculate_next_run_at
from models.trigger import AppTrigger, AppTriggerStatus, AppTriggerType, WorkflowSchedulePlan
from tasks.workflow_schedule_tasks import run_schedule_trigger
logger = logging.getLogger(__name__)
@shared_task(queue="schedule_poller")
def poll_workflow_schedules() -> None:
"""
Poll and process due workflow schedules.
Streaming flow:
1. Fetch due schedules in batches
2. Process each batch until all due schedules are handled
3. Optional: Limit total dispatches per tick as a circuit breaker
"""
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
total_dispatched = 0
# Process in batches until we've handled all due schedules or hit the limit
while True:
due_schedules = _fetch_due_schedules(session)
if not due_schedules:
break
dispatched_count = _process_schedules(session, due_schedules)
total_dispatched += dispatched_count
logger.debug("Batch processed: %d dispatched", dispatched_count)
# Circuit breaker: check if we've hit the per-tick limit (if enabled)
if (
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK > 0
and total_dispatched >= dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK
):
logger.warning(
"Circuit breaker activated: reached dispatch limit (%d), will continue next tick",
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK,
)
break
if total_dispatched > 0:
logger.info("Total processed: %d dispatched", total_dispatched)

def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]:
"""
Fetch a batch of due schedules, sorted by most overdue first.
Returns up to WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE schedules per call.
Used in a loop to progressively process all due schedules.
"""
now = naive_utc_now()
due_schedules = session.scalars(
(
select(WorkflowSchedulePlan)
.join(
AppTrigger,
and_(
AppTrigger.app_id == WorkflowSchedulePlan.app_id,
AppTrigger.node_id == WorkflowSchedulePlan.node_id,
AppTrigger.trigger_type == AppTriggerType.TRIGGER_SCHEDULE,
),
)
.where(
WorkflowSchedulePlan.next_run_at <= now,
WorkflowSchedulePlan.next_run_at.isnot(None),
AppTrigger.status == AppTriggerStatus.ENABLED,
)
)
.order_by(WorkflowSchedulePlan.next_run_at.asc())
.with_for_update(skip_locked=True)
.limit(dify_config.WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE)
)
return list(due_schedules)

def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan]) -> int:
"""Process schedules: check quota, update next run time and dispatch to Celery in parallel."""
if not schedules:
return 0
tasks_to_dispatch: list[str] = []
for schedule in schedules:
next_run_at = calculate_next_run_at(
schedule.cron_expression,
schedule.timezone,
)
schedule.next_run_at = next_run_at
tasks_to_dispatch.append(schedule.id)
if tasks_to_dispatch:
job = group(run_schedule_trigger.s(schedule_id) for schedule_id in tasks_to_dispatch)
job.apply_async()
logger.debug("Dispatched %d tasks in parallel", len(tasks_to_dispatch))
session.commit()
return len(tasks_to_dispatch)
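
calculate_next_run_at is imported from libs.schedule_utils and its implementation is not part of this diff. For orientation only, a typical next-run computation with croniter might look like the sketch below (an assumption about the approach, not the commit's code); it evaluates the cron expression in the schedule's timezone and returns a naive UTC datetime so it compares cleanly with naive_utc_now():

from datetime import datetime, timezone
from zoneinfo import ZoneInfo

from croniter import croniter

def next_run_at_utc_naive(cron_expression: str, tz_name: str) -> datetime:
    """Next fire time for the cron expression, converted to a naive UTC datetime."""
    local_now = datetime.now(ZoneInfo(tz_name))
    next_local = croniter(cron_expression, local_now).get_next(datetime)
    return next_local.astimezone(timezone.utc).replace(tzinfo=None)

# next_run_at_utc_naive("*/5 * * * *", "Asia/Shanghai")  # e.g. the next 5-minute boundary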