2025-12-01 17:21:38 +08:00
parent 32fee2b8ab
commit fab8c13cb3
7511 changed files with 996300 additions and 0 deletions

@@ -0,0 +1,117 @@
import logging
import time
import click
from celery import shared_task
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DatasetAutoDisableLog, DocumentSegment
from models.dataset import Document as DatasetDocument
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def add_document_to_index_task(dataset_document_id: str):
"""
Async Add document to index
:param dataset_document_id:
Usage: add_document_to_index_task.delay(dataset_document_id)
"""
logger.info(click.style(f"Start add document to index: {dataset_document_id}", fg="green"))
start_at = time.perf_counter()
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first()
if not dataset_document:
logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
db.session.close()
return
if dataset_document.indexing_status != "completed":
db.session.close()
return
indexing_cache_key = f"document_{dataset_document.id}_indexing"
try:
dataset = dataset_document.dataset
if not dataset:
raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.")
segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed",
)
.order_by(DocumentSegment.position.asc())
.all()
)
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document)
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, documents)
# delete auto disable log
db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete()
# update segment to enable
db.session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update(
{
DocumentSegment.enabled: True,
DocumentSegment.disabled_at: None,
DocumentSegment.disabled_by: None,
DocumentSegment.updated_at: naive_utc_now(),
}
)
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green")
)
except Exception as e:
logger.exception("add document to index failed")
dataset_document.enabled = False
dataset_document.disabled_at = naive_utc_now()
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
db.session.close()
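
The task above signals completion by deleting the `document_{id}_indexing` Redis key in its `finally` block. A minimal caller-side sketch of that cache-key handshake, assuming redis-py and a hypothetical document id (the key format mirrors the one used in the task):

import redis

r = redis.Redis()
document_id = "9a1b2c3d-hypothetical-id"
cache_key = f"document_{document_id}_indexing"
if not r.get(cache_key):
    # mark the document as "indexing" for up to 10 minutes, then enqueue the task above
    r.setex(cache_key, 600, 1)
    # add_document_to_index_task.delay(document_id)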


@@ -0,0 +1,62 @@
import logging
import time
import click
from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def add_annotation_to_index_task(
annotation_id: str, question: str, tenant_id: str, app_id: str, collection_binding_id: str
):
"""
Add annotation to index.
:param annotation_id: annotation id
:param question: question
:param tenant_id: tenant id
:param app_id: app id
:param collection_binding_id: embedding binding id
    Usage: add_annotation_to_index_task.delay(annotation_id, question, tenant_id, app_id, collection_binding_id)
"""
logger.info(click.style(f"Start build index for annotation: {annotation_id}", fg="green"))
start_at = time.perf_counter()
try:
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
collection_binding_id, "annotation"
)
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,
)
document = Document(
page_content=question, metadata={"annotation_id": annotation_id, "app_id": app_id, "doc_id": annotation_id}
)
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
vector.create([document], duplicate_check=True)
end_at = time.perf_counter()
logger.info(
click.style(
f"Build index successful for annotation: {annotation_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("Build index for annotation failed")
finally:
db.session.close()
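
A hypothetical call site for the task above (all ids are placeholders); the question text is what gets indexed as the document's page_content:

add_annotation_to_index_task.delay(
    "annotation-id",
    "How do I enable annotation reply?",  # question
    "tenant-id",
    "app-id",
    "collection-binding-id",
)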


@@ -0,0 +1,94 @@
import logging
import time
import click
from celery import shared_task
from werkzeug.exceptions import NotFound
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
from services.dataset_service import DatasetCollectionBindingService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: str, tenant_id: str, user_id: str):
"""
    Batch import annotations and add them to the index.
:param job_id: job_id
:param content_list: content list
:param app_id: app id
:param tenant_id: tenant id
:param user_id: user_id
"""
logger.info(click.style(f"Start batch import annotation: {job_id}", fg="green"))
start_at = time.perf_counter()
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
# get app info
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
if app:
try:
documents = []
for content in content_list:
annotation = MessageAnnotation(
app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id
)
db.session.add(annotation)
db.session.flush()
document = Document(
page_content=content["question"],
metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
)
documents.append(document)
            # if annotation reply is enabled, batch add annotations' index
app_annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
if app_annotation_setting:
dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
app_annotation_setting.collection_binding_id, "annotation"
)
)
if not dataset_collection_binding:
                    raise NotFound("Dataset collection binding not found")
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,
)
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
vector.create(documents, duplicate_check=True)
db.session.commit()
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
"Build index successful for batch import annotation: {} latency: {}".format(
job_id, end_at - start_at
),
fg="green",
)
)
except Exception as e:
db.session.rollback()
redis_client.setex(indexing_cache_key, 600, "error")
indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}"
redis_client.setex(indexing_error_msg_key, 600, str(e))
logger.exception("Build index for batch import annotations failed")
finally:
db.session.close()
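
For reference, each item in the `content_list` consumed above only needs `question` and `answer` keys; a hypothetical payload (values are made up) looks like this:

content_list = [
    {"question": "How do I reset my password?", "answer": "Open account settings and choose reset."},
    {"question": "Where are annotations stored?", "answer": "They are saved as MessageAnnotation rows."},
]
# batch_import_annotations_task.delay(job_id, content_list, app_id, tenant_id, user_id)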


@@ -0,0 +1,44 @@
import logging
import time
import click
from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str, collection_binding_id: str):
"""
Async delete annotation index task
"""
logger.info(click.style(f"Start delete app annotation index: {app_id}", fg="green"))
start_at = time.perf_counter()
try:
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
collection_binding_id, "annotation"
)
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
collection_binding_id=dataset_collection_binding.id,
)
try:
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
vector.delete_by_metadata_field("annotation_id", annotation_id)
except Exception:
logger.exception("Delete annotation index failed when annotation deleted.")
end_at = time.perf_counter()
logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("Annotation deleted index failed")
finally:
db.session.close()


@@ -0,0 +1,71 @@
import logging
import time
import click
from celery import shared_task
from sqlalchemy import exists, select
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
"""
    Async disable annotation reply task
"""
logger.info(click.style(f"Start delete app annotations index: {app_id}", fg="green"))
start_at = time.perf_counter()
# get app info
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
annotations_exists = db.session.scalar(select(exists().where(MessageAnnotation.app_id == app_id)))
if not app:
logger.info(click.style(f"App not found: {app_id}", fg="red"))
db.session.close()
return
app_annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if not app_annotation_setting:
logger.info(click.style(f"App annotation setting not found: {app_id}", fg="red"))
db.session.close()
return
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}"
try:
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
collection_binding_id=app_annotation_setting.collection_binding_id,
)
try:
if annotations_exists:
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
vector.delete()
except Exception:
logger.exception("Delete annotation index failed when annotation deleted.")
redis_client.setex(disable_app_annotation_job_key, 600, "completed")
# delete annotation setting
db.session.delete(app_annotation_setting)
db.session.commit()
end_at = time.perf_counter()
logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception("Annotation batch deleted index failed")
redis_client.setex(disable_app_annotation_job_key, 600, "error")
disable_app_annotation_error_key = f"disable_app_annotation_error_{str(job_id)}"
redis_client.setex(disable_app_annotation_error_key, 600, str(e))
finally:
redis_client.delete(disable_app_annotation_key)
db.session.close()


@@ -0,0 +1,124 @@
import logging
import time
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
from services.dataset_service import DatasetCollectionBindingService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def enable_annotation_reply_task(
job_id: str,
app_id: str,
user_id: str,
tenant_id: str,
score_threshold: float,
embedding_provider_name: str,
embedding_model_name: str,
):
"""
Async enable annotation reply task
"""
logger.info(click.style(f"Start add app annotation to index: {app_id}", fg="green"))
start_at = time.perf_counter()
# get app info
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
if not app:
logger.info(click.style(f"App not found: {app_id}", fg="red"))
db.session.close()
return
annotations = db.session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all()
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
try:
documents = []
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name, embedding_model_name, "annotation"
)
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if annotation_setting:
if dataset_collection_binding.id != annotation_setting.collection_binding_id:
old_dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
annotation_setting.collection_binding_id, "annotation"
)
)
if old_dataset_collection_binding and annotations:
old_dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
embedding_model_provider=old_dataset_collection_binding.provider_name,
embedding_model=old_dataset_collection_binding.model_name,
collection_binding_id=old_dataset_collection_binding.id,
)
old_vector = Vector(old_dataset, attributes=["doc_id", "annotation_id", "app_id"])
try:
old_vector.delete()
except Exception as e:
logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
annotation_setting.score_threshold = score_threshold
annotation_setting.collection_binding_id = dataset_collection_binding.id
annotation_setting.updated_user_id = user_id
annotation_setting.updated_at = naive_utc_now()
db.session.add(annotation_setting)
else:
new_app_annotation_setting = AppAnnotationSetting(
app_id=app_id,
score_threshold=score_threshold,
collection_binding_id=dataset_collection_binding.id,
created_user_id=user_id,
updated_user_id=user_id,
)
db.session.add(new_app_annotation_setting)
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
embedding_model_provider=embedding_provider_name,
embedding_model=embedding_model_name,
collection_binding_id=dataset_collection_binding.id,
)
if annotations:
for annotation in annotations:
document = Document(
page_content=annotation.question,
metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
)
documents.append(document)
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
try:
vector.delete_by_metadata_field("app_id", app_id)
except Exception as e:
logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
vector.create(documents)
db.session.commit()
redis_client.setex(enable_app_annotation_job_key, 600, "completed")
end_at = time.perf_counter()
logger.info(click.style(f"App annotations added to index: {app_id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception("Annotation batch created index failed")
redis_client.setex(enable_app_annotation_job_key, 600, "error")
enable_app_annotation_error_key = f"enable_app_annotation_error_{str(job_id)}"
redis_client.setex(enable_app_annotation_error_key, 600, str(e))
db.session.rollback()
finally:
redis_client.delete(enable_app_annotation_key)
db.session.close()
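
The enable/disable tasks above report progress through short-lived Redis keys (`..._job_{job_id}` and `..._error_{job_id}`, 600-second TTL). A minimal polling sketch on the caller side, assuming redis-py and a hypothetical job id:

import redis

r = redis.Redis()
job_id = "job-123"  # hypothetical
status = r.get(f"enable_app_annotation_job_{job_id}")  # b"completed", b"error", or None (still running or expired)
if status == b"error":
    error_message = r.get(f"enable_app_annotation_error_{job_id}")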


@@ -0,0 +1,63 @@
import logging
import time
import click
from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def update_annotation_to_index_task(
annotation_id: str, question: str, tenant_id: str, app_id: str, collection_binding_id: str
):
"""
Update annotation to index.
:param annotation_id: annotation id
:param question: question
:param tenant_id: tenant id
:param app_id: app id
:param collection_binding_id: embedding binding id
    Usage: update_annotation_to_index_task.delay(annotation_id, question, tenant_id, app_id, collection_binding_id)
"""
logger.info(click.style(f"Start update index for annotation: {annotation_id}", fg="green"))
start_at = time.perf_counter()
try:
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
collection_binding_id, "annotation"
)
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,
)
document = Document(
page_content=question, metadata={"annotation_id": annotation_id, "app_id": app_id, "doc_id": annotation_id}
)
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
vector.delete_by_metadata_field("annotation_id", annotation_id)
vector.add_texts([document])
end_at = time.perf_counter()
logger.info(
click.style(
f"Build index successful for annotation: {annotation_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("Build index for annotation failed")
finally:
db.session.close()


@@ -0,0 +1,196 @@
"""
Celery tasks for async workflow execution.
These tasks handle workflow execution for different subscription tiers
with appropriate retry policies and error handling.
"""
from datetime import UTC, datetime
from typing import Any
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.layers.trigger_post_layer import TriggerPostLayer
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatorUserRole, WorkflowTriggerStatus
from models.model import App, EndUser, Tenant
from models.trigger import WorkflowTriggerLog
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import WorkflowNotFoundError
from services.workflow.entities import (
TriggerData,
WorkflowTaskData,
)
from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity, AsyncWorkflowCFSPlanScheduler
from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkflowSystemStrategy
@shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE)
def execute_workflow_professional(task_data_dict: dict[str, Any]):
"""Execute workflow for professional tier with highest priority"""
task_data = WorkflowTaskData.model_validate(task_data_dict)
cfs_plan_scheduler_entity = AsyncWorkflowCFSPlanEntity(
queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE,
schedule_strategy=AsyncWorkflowSystemStrategy,
granularity=dify_config.ASYNC_WORKFLOW_SCHEDULER_GRANULARITY,
)
_execute_workflow_common(
task_data,
AsyncWorkflowCFSPlanScheduler(plan=cfs_plan_scheduler_entity),
cfs_plan_scheduler_entity,
)
@shared_task(queue=AsyncWorkflowQueue.TEAM_QUEUE)
def execute_workflow_team(task_data_dict: dict[str, Any]):
"""Execute workflow for team tier"""
task_data = WorkflowTaskData.model_validate(task_data_dict)
cfs_plan_scheduler_entity = AsyncWorkflowCFSPlanEntity(
queue=AsyncWorkflowQueue.TEAM_QUEUE,
schedule_strategy=AsyncWorkflowSystemStrategy,
granularity=dify_config.ASYNC_WORKFLOW_SCHEDULER_GRANULARITY,
)
_execute_workflow_common(
task_data,
AsyncWorkflowCFSPlanScheduler(plan=cfs_plan_scheduler_entity),
cfs_plan_scheduler_entity,
)
@shared_task(queue=AsyncWorkflowQueue.SANDBOX_QUEUE)
def execute_workflow_sandbox(task_data_dict: dict[str, Any]):
"""Execute workflow for free tier with lower retry limit"""
task_data = WorkflowTaskData.model_validate(task_data_dict)
cfs_plan_scheduler_entity = AsyncWorkflowCFSPlanEntity(
queue=AsyncWorkflowQueue.SANDBOX_QUEUE,
schedule_strategy=AsyncWorkflowSystemStrategy,
granularity=dify_config.ASYNC_WORKFLOW_SCHEDULER_GRANULARITY,
)
_execute_workflow_common(
task_data,
AsyncWorkflowCFSPlanScheduler(plan=cfs_plan_scheduler_entity),
cfs_plan_scheduler_entity,
)
def _build_generator_args(trigger_data: TriggerData) -> dict[str, Any]:
"""Build args passed into WorkflowAppGenerator.generate for Celery executions."""
args: dict[str, Any] = {
"inputs": dict(trigger_data.inputs),
"files": list(trigger_data.files),
SKIP_PREPARE_USER_INPUTS_KEY: True,
}
return args
def _execute_workflow_common(
task_data: WorkflowTaskData,
cfs_plan_scheduler: AsyncWorkflowCFSPlanScheduler,
cfs_plan_scheduler_entity: AsyncWorkflowCFSPlanEntity,
):
"""Execute workflow with common logic and trigger log updates."""
# Create a new session for this task
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
# Get trigger log
trigger_log = trigger_log_repo.get_by_id(task_data.workflow_trigger_log_id)
if not trigger_log:
# This should not happen, but handle gracefully
return
# Reconstruct execution data from trigger log
trigger_data = TriggerData.model_validate_json(trigger_log.trigger_data)
# Update status to running
trigger_log.status = WorkflowTriggerStatus.RUNNING
trigger_log_repo.update(trigger_log)
session.commit()
start_time = datetime.now(UTC)
try:
# Get app and workflow models
app_model = session.scalar(select(App).where(App.id == trigger_log.app_id))
if not app_model:
raise WorkflowNotFoundError(f"App not found: {trigger_log.app_id}")
workflow = session.scalar(select(Workflow).where(Workflow.id == trigger_log.workflow_id))
if not workflow:
raise WorkflowNotFoundError(f"Workflow not found: {trigger_log.workflow_id}")
user = _get_user(session, trigger_log)
# Execute workflow using WorkflowAppGenerator
generator = WorkflowAppGenerator()
# Prepare args matching AppGenerateService.generate format
args = _build_generator_args(trigger_data)
# If workflow_id was specified, add it to args
if trigger_data.workflow_id:
args["workflow_id"] = str(trigger_data.workflow_id)
# Execute the workflow with the trigger type
generator.generate(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
call_depth=0,
triggered_from=trigger_data.trigger_from,
root_node_id=trigger_data.root_node_id,
graph_engine_layers=[
# TODO: Re-enable TimeSliceLayer after the HITL release.
TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory),
],
)
except Exception as e:
# Calculate elapsed time for failed execution
elapsed_time = (datetime.now(UTC) - start_time).total_seconds()
# Update trigger log with failure
trigger_log.status = WorkflowTriggerStatus.FAILED
trigger_log.error = str(e)
trigger_log.finished_at = datetime.now(UTC)
trigger_log.elapsed_time = elapsed_time
trigger_log_repo.update(trigger_log)
# Final failure - no retry logic (simplified like RAG tasks)
session.commit()
def _get_user(session: Session, trigger_log: WorkflowTriggerLog) -> Account | EndUser:
"""Compose user from trigger log"""
tenant = session.scalar(select(Tenant).where(Tenant.id == trigger_log.tenant_id))
if not tenant:
raise ValueError(f"Tenant not found: {trigger_log.tenant_id}")
# Get user from trigger log
if trigger_log.created_by_role == CreatorUserRole.ACCOUNT:
user = session.scalar(select(Account).where(Account.id == trigger_log.created_by))
if user:
user.current_tenant = tenant
else: # CreatorUserRole.END_USER
user = session.scalar(select(EndUser).where(EndUser.id == trigger_log.created_by))
if not user:
raise ValueError(f"User not found: {trigger_log.created_by} (role: {trigger_log.created_by_role})")
return user
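
The three tier-specific tasks above all delegate to `_execute_workflow_common`; a hypothetical dispatch helper (the tier names and this calling convention are assumptions, not part of this commit) could route a validated `WorkflowTaskData` to the matching queue:

TIER_TASKS = {
    "professional": execute_workflow_professional,
    "team": execute_workflow_team,
    "sandbox": execute_workflow_sandbox,
}

def dispatch_workflow(tier: str, task_data: WorkflowTaskData) -> None:
    task = TIER_TASKS.get(tier, execute_workflow_sandbox)  # fall back to the sandbox queue
    task.delay(task_data.model_dump(mode="json"))  # each task re-validates with model_validate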


@@ -0,0 +1,92 @@
import logging
import time
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment
from models.model import UploadFile
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]):
"""
Clean document when document deleted.
:param document_ids: document ids
:param dataset_id: dataset id
:param doc_form: doc_form
:param file_ids: file ids
Usage: batch_clean_document_task.delay(document_ids, dataset_id)
"""
logger.info(click.style("Start batch clean documents when documents deleted", fg="green"))
start_at = time.perf_counter()
try:
if not doc_form:
raise ValueError("doc_form is required")
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
db.session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id.in_(document_ids),
).delete(synchronize_session=False)
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
).all()
        # check if segments exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
try:
if image_file and image_file.key:
storage.delete(image_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
upload_file_id,
)
db.session.delete(image_file)
db.session.delete(segment)
db.session.commit()
if file_ids:
files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
for file in files:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
db.session.delete(file)
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned documents when documents deleted latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("Cleaned documents when documents deleted failed")
finally:
db.session.close()


@@ -0,0 +1,151 @@
import logging
import tempfile
import time
import uuid
from pathlib import Path
import click
import pandas as pd
from celery import shared_task
from sqlalchemy import func
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.model import UploadFile
from services.vector_service import VectorService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def batch_create_segment_to_index_task(
job_id: str,
upload_file_id: str,
dataset_id: str,
document_id: str,
tenant_id: str,
user_id: str,
):
"""
Async batch create segment to index
:param job_id:
:param upload_file_id:
:param dataset_id:
:param document_id:
:param tenant_id:
:param user_id:
Usage: batch_create_segment_to_index_task.delay(job_id, upload_file_id, dataset_id, document_id, tenant_id, user_id)
"""
logger.info(click.style(f"Start batch create segment jobId: {job_id}", fg="green"))
start_at = time.perf_counter()
indexing_cache_key = f"segment_batch_import_{job_id}"
try:
dataset = db.session.get(Dataset, dataset_id)
if not dataset:
raise ValueError("Dataset not exist.")
dataset_document = db.session.get(Document, document_id)
if not dataset_document:
raise ValueError("Document not exist.")
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
raise ValueError("Document is not available.")
upload_file = db.session.get(UploadFile, upload_file_id)
if not upload_file:
raise ValueError("UploadFile not found.")
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file.key, file_path)
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if dataset_document.doc_form == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
data = {"content": row.iloc[0]}
content.append(data)
if len(content) == 0:
raise ValueError("The CSV file is empty.")
document_segments = []
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(
texts=[segment["content"] for segment in content]
)
else:
tokens_list = [0] * len(content)
for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == dataset_document.id)
.scalar()
)
segment_document = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset_id,
document_id=document_id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
created_by=user_id,
indexing_at=naive_utc_now(),
status="completed",
completed_at=naive_utc_now(),
)
if dataset_document.doc_form == "qa_model":
segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count
db.session.add(segment_document)
document_segments.append(segment_document)
assert dataset_document.word_count is not None
dataset_document.word_count += word_count_change
db.session.add(dataset_document)
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
db.session.commit()
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("Segments batch created index failed")
redis_client.setex(indexing_cache_key, 600, "error")
finally:
db.session.close()
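
The CSV parsed above is read positionally: column 0 becomes the segment content and, for `qa_model` documents only, column 1 becomes the answer. A small illustrative file (contents are made up) can be produced with pandas:

import pandas as pd

pd.DataFrame(
    [
        {"content": "What is a vector index?", "answer": "A structure for similarity search."},
        {"content": "What does this task do?", "answer": "It imports segments from a CSV file."},
    ]
).to_csv("segments.csv", index=False)
# pd.read_csv("segments.csv") yields rows where row.iloc[0] / row.iloc[1] match the loop above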


@@ -0,0 +1,154 @@
import logging
import time
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import (
AppDatasetJoin,
Dataset,
DatasetMetadata,
DatasetMetadataBinding,
DatasetProcessRule,
DatasetQuery,
Document,
DocumentSegment,
)
from models.model import UploadFile
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def clean_dataset_task(
dataset_id: str,
tenant_id: str,
indexing_technique: str,
index_struct: str,
collection_binding_id: str,
doc_form: str,
):
"""
Clean dataset when dataset deleted.
:param dataset_id: dataset id
:param tenant_id: tenant id
:param indexing_technique: indexing technique
:param index_struct: index struct dict
:param collection_binding_id: collection binding id
:param doc_form: dataset form
    Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct, collection_binding_id, doc_form)
"""
logger.info(click.style(f"Start clean dataset when dataset deleted: {dataset_id}", fg="green"))
start_at = time.perf_counter()
try:
dataset = Dataset(
id=dataset_id,
tenant_id=tenant_id,
indexing_technique=indexing_technique,
index_struct=index_struct,
collection_binding_id=collection_binding_id,
)
documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
# Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
# This ensures all invalid doc_form values are properly handled
if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
# Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
from core.rag.index_processor.constant.index_type import IndexType
doc_form = IndexType.PARAGRAPH_INDEX
logger.info(
click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow")
)
# Add exception handling around IndexProcessorFactory.clean() to prevent single point of failure
# This ensures Document/Segment deletion can continue even if vector database cleanup fails
try:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
logger.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green"))
except Exception:
logger.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red"))
# Continue with document and segment deletion even if vector cleanup fails
logger.info(
click.style(f"Continuing with document and segment deletion for dataset: {dataset_id}", fg="yellow")
)
if documents is None or len(documents) == 0:
logger.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green"))
else:
logger.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green"))
for document in documents:
db.session.delete(document)
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if image_file is None:
continue
try:
storage.delete(image_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
upload_file_id,
)
db.session.delete(image_file)
db.session.delete(segment)
db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
db.session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete()
# delete dataset metadata
db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
db.session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
# delete files
if documents:
for document in documents:
try:
if document.data_source_type == "upload_file":
if document.data_source_info:
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
file = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
.first()
)
if not file:
continue
storage.delete(file.key)
db.session.delete(file)
except Exception:
continue
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(f"Cleaned dataset when dataset deleted: {dataset_id} latency: {end_at - start_at}", fg="green")
)
except Exception:
# Add rollback to prevent dirty session state in case of exceptions
# This ensures the database session is properly cleaned up
try:
db.session.rollback()
logger.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow"))
except Exception:
logger.exception("Failed to rollback database session")
logger.exception("Cleaned dataset when dataset deleted failed")
finally:
db.session.close()


@@ -0,0 +1,90 @@
import logging
import time
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment
from models.model import UploadFile
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_id: str | None):
"""
Clean document when document deleted.
:param document_id: document id
:param dataset_id: dataset id
:param doc_form: doc_form
:param file_id: file id
Usage: clean_document_task.delay(document_id, dataset_id)
"""
logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
        # check if segments exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if image_file is None:
continue
try:
storage.delete(image_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
upload_file_id,
)
db.session.delete(image_file)
db.session.delete(segment)
db.session.commit()
if file_id:
file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if file:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
db.session.delete(file)
db.session.commit()
# delete dataset metadata binding
db.session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("Cleaned document when document deleted failed")
finally:
db.session.close()


@@ -0,0 +1,60 @@
import logging
import time
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def clean_notion_document_task(document_ids: list[str], dataset_id: str):
"""
Clean document when document deleted.
:param document_ids: document ids
:param dataset_id: dataset id
Usage: clean_notion_document_task.delay(document_ids, dataset_id)
"""
logger.info(click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
for document_id in document_ids:
            document = db.session.query(Document).where(Document.id == document_id).first()
            if document:
                db.session.delete(document)
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
"Clean document when import form notion document deleted end :: {} latency: {}".format(
dataset_id, end_at - start_at
),
fg="green",
)
)
except Exception:
logger.exception("Cleaned document when import form notion document deleted failed")
finally:
db.session.close()


@@ -0,0 +1,99 @@
import logging
import time
import click
from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = None):
"""
Async create segment to index
:param segment_id:
:param keywords:
Usage: create_segment_to_index_task.delay(segment_id)
"""
logger.info(click.style(f"Start create segment to index: {segment_id}", fg="green"))
start_at = time.perf_counter()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
db.session.close()
return
if segment.status != "waiting":
db.session.close()
return
indexing_cache_key = f"segment_{segment.id}_indexing"
try:
# update segment status to indexing
db.session.query(DocumentSegment).filter_by(id=segment.id).update(
{
DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: naive_utc_now(),
}
)
db.session.commit()
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
dataset = segment.dataset
if not dataset:
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
return
dataset_document = segment.document
if not dataset_document:
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, [document])
# update segment to completed
db.session.query(DocumentSegment).filter_by(id=segment.id).update(
{
DocumentSegment.status: "completed",
DocumentSegment.completed_at: naive_utc_now(),
}
)
db.session.commit()
end_at = time.perf_counter()
logger.info(click.style(f"Segment created to index: {segment.id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception("create segment to index failed")
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.error = str(e)
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
db.session.close()


@@ -0,0 +1,171 @@
import logging
import time
import click
from celery import shared_task # type: ignore
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def deal_dataset_index_update_task(dataset_id: str, action: str):
"""
Async deal dataset from index
:param dataset_id: dataset_id
:param action: action
Usage: deal_dataset_index_update_task.delay(dataset_id, action)
"""
logging.info(click.style("Start deal dataset index update: {}".format(dataset_id), fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "upgrade":
dataset_documents = (
db.session.query(DatasetDocument)
.where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
for dataset_document in dataset_documents:
try:
# add from vector index
segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
if segments:
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
documents.append(document)
# save vector index
# clean keywords
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
elif action == "update":
dataset_documents = (
db.session.query(DatasetDocument)
.where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
# add new index
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
# clean index
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
for dataset_document in dataset_documents:
# update from vector index
try:
segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
if segments:
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
else:
# clean collection
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
end_at = time.perf_counter()
        logger.info(
            click.style(f"Deal dataset index update: {dataset_id} latency: {end_at - start_at}", fg="green")
        )
    except Exception:
        logger.exception("Deal dataset index update failed")
finally:
db.session.close()


@@ -0,0 +1,169 @@
import logging
import time
from typing import Literal
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "add", "update"]):
"""
Async deal dataset from index
:param dataset_id: dataset_id
:param action: action
Usage: deal_dataset_vector_index_task.delay(dataset_id, action)
"""
logger.info(click.style(f"Start deal dataset vector index: {dataset_id}", fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "remove":
index_processor.clean(dataset, None, with_keywords=False)
elif action == "add":
dataset_documents = db.session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
).all()
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
for dataset_document in dataset_documents:
try:
# add from vector index
segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
if segments:
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
elif action == "update":
dataset_documents = db.session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
).all()
# add new index
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
# clean index
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
for dataset_document in dataset_documents:
# update from vector index
try:
segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
if segments:
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
else:
# clean collection
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
end_at = time.perf_counter()
logger.info(click.style(f"Deal dataset vector index: {dataset_id} latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("Deal dataset vector index failed")
finally:
db.session.close()


@@ -0,0 +1,26 @@
import logging
from celery import shared_task
from extensions.ext_database import db
from models import Account
from services.billing_service import BillingService
from tasks.mail_account_deletion_task import send_deletion_success_task
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def delete_account_task(account_id):
account = db.session.query(Account).where(Account.id == account_id).first()
try:
BillingService.delete_account(account_id)
except Exception:
logger.exception("Failed to delete account %s from billing service.", account_id)
raise
if not account:
logger.error("Account %s not found.", account_id)
return
# send success email
send_deletion_success_task.delay(account.email)


@@ -0,0 +1,70 @@
import logging
import time
import click
from celery import shared_task
from extensions.ext_database import db
from models import ConversationVariable
from models.model import Message, MessageAnnotation, MessageFeedback
from models.tools import ToolConversationVariables, ToolFile
from models.web import PinnedConversation
logger = logging.getLogger(__name__)
@shared_task(queue="conversation")
def delete_conversation_related_data(conversation_id: str):
"""
    Delete conversation-related data in the correct order from the database to respect foreign key constraints
Args:
conversation_id: conversation Id
"""
logger.info(
click.style(f"Starting to delete conversation data from db for conversation_id {conversation_id}", fg="green")
)
start_at = time.perf_counter()
try:
db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete(
synchronize_session=False
)
db.session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete(
synchronize_session=False
)
db.session.query(ToolConversationVariables).where(
ToolConversationVariables.conversation_id == conversation_id
).delete(synchronize_session=False)
db.session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False)
db.session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete(
synchronize_session=False
)
db.session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False)
db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
synchronize_session=False
)
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"Succeeded cleaning data from db for conversation_id {conversation_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception as e:
logger.exception("Failed to delete data from db for conversation_id: %s failed", conversation_id)
db.session.rollback()
raise e
finally:
db.session.close()
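# The deletions above run child-first (annotations, feedback, tool data, conversation
# variables, messages, pinned entries) so foreign key constraints are respected; the
# Conversation row itself is not removed by this task and is presumably handled by the
# caller.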

View File

@@ -0,0 +1,58 @@
import logging
import time
import click
from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def delete_segment_from_index_task(
index_node_ids: list, dataset_id: str, document_id: str, child_node_ids: list | None = None
):
"""
Async Remove segment from index
:param index_node_ids:
:param dataset_id:
    :param document_id:
    :param child_node_ids:
Usage: delete_segment_from_index_task.delay(index_node_ids, dataset_id, document_id)
"""
logger.info(click.style("Start delete segment from index", fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logging.warning("Dataset %s not found, skipping index cleanup", dataset_id)
return
dataset_document = db.session.query(Document).where(Document.id == document_id).first()
if not dataset_document:
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logging.info("Document not in valid state for index operations, skipping")
return
doc_form = dataset_document.doc_form
# Proceed with index cleanup using the index_node_ids directly
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(
dataset,
index_node_ids,
with_keywords=True,
delete_child_chunks=True,
precomputed_child_node_ids=child_node_ids,
)
end_at = time.perf_counter()
logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("delete segment from index failed")
finally:
db.session.close()
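# Hedged usage sketch (all ids below are placeholders): passing precomputed child node
# ids lets the processor drop child chunks for parent-child datasets without looking
# them up again.
#
#   delete_segment_from_index_task.delay(
#       index_node_ids=["<node-id>"],
#       dataset_id="<dataset-id>",
#       document_id="<document-id>",
#       child_node_ids=["<child-node-id>"],
#   )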

View File

@@ -0,0 +1,68 @@
import logging
import time
import click
from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def disable_segment_from_index_task(segment_id: str):
"""
Async disable segment from index
:param segment_id:
Usage: disable_segment_from_index_task.delay(segment_id)
"""
logger.info(click.style(f"Start disable segment from index: {segment_id}", fg="green"))
start_at = time.perf_counter()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
db.session.close()
return
if segment.status != "completed":
logger.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red"))
db.session.close()
return
indexing_cache_key = f"segment_{segment.id}_indexing"
try:
dataset = segment.dataset
if not dataset:
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
return
dataset_document = segment.document
if not dataset_document:
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.clean(dataset, [segment.index_node_id])
end_at = time.perf_counter()
logger.info(click.style(f"Segment removed from index: {segment.id} latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("remove segment from index failed")
segment.enabled = True
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
db.session.close()

View File

@@ -0,0 +1,84 @@
import logging
import time
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def disable_segments_from_index_task(segment_ids: list, dataset_id: str, document_id: str):
"""
Async disable segments from index
:param segment_ids: list of segment ids
:param dataset_id: dataset id
:param document_id: document id
Usage: disable_segments_from_index_task.delay(segment_ids, dataset_id, document_id)
"""
start_at = time.perf_counter()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
db.session.close()
return
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
if not dataset_document:
logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
db.session.close()
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
db.session.close()
return
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
).all()
if not segments:
db.session.close()
return
try:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
end_at = time.perf_counter()
logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green"))
except Exception:
# update segment error msg
db.session.query(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
).update(
{
"disabled_at": None,
"disabled_by": None,
"enabled": True,
}
)
db.session.commit()
finally:
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
redis_client.delete(indexing_cache_key)
db.session.close()

View File

@@ -0,0 +1,124 @@
import logging
import time
import click
import sqlalchemy as sa
from celery import shared_task
from sqlalchemy import select
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.source import DataSourceOauthBinding
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def document_indexing_sync_task(dataset_id: str, document_id: str):
"""
Async update document
:param dataset_id:
:param document_id:
Usage: document_indexing_sync_task.delay(dataset_id, document_id)
"""
logger.info(click.style(f"Start sync document: {document_id}", fg="green"))
start_at = time.perf_counter()
document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
db.session.close()
return
data_source_info = document.data_source_info_dict
if document.data_source_type == "notion_import":
if (
not data_source_info
or "notion_page_id" not in data_source_info
or "notion_workspace_id" not in data_source_info
):
raise ValueError("no notion page found")
workspace_id = data_source_info["notion_workspace_id"]
page_id = data_source_info["notion_page_id"]
page_type = data_source_info["type"]
page_edited_time = data_source_info["last_edited_time"]
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.where(
sa.and_(
DataSourceOauthBinding.tenant_id == document.tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
)
.first()
)
if not data_source_binding:
raise ValueError("Data source binding not found.")
loader = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=data_source_binding.access_token,
tenant_id=document.tenant_id,
)
last_edited_time = loader.get_notion_last_edited_time()
# check the page is updated
if last_edited_time != page_edited_time:
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
db.session.commit()
# delete all document segment and index
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
end_at = time.perf_counter()
logger.info(
click.style(
"Cleaned document when document update data source or process rule: {} latency: {}".format(
document_id, end_at - start_at
),
fg="green",
)
)
except Exception:
logger.exception("Cleaned document when document update data source or process rule failed")
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_sync_task failed, document_id: %s", document_id)
finally:
db.session.close()
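# Sync decision above: the stored last_edited_time is compared against the live value
# returned by NotionExtractor, and only a mismatch wipes the existing segments and
# re-runs IndexingRunner; an unchanged page leaves the document untouched.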

View File

@@ -0,0 +1,168 @@
import logging
import time
from collections.abc import Callable, Sequence
import click
from celery import shared_task
from configs import dify_config
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def document_indexing_task(dataset_id: str, document_ids: list):
"""
Async process document
:param dataset_id:
:param document_ids:
.. warning:: TO BE DEPRECATED
This function will be deprecated and removed in a future version.
Use normal_document_indexing_task or priority_document_indexing_task instead.
Usage: document_indexing_task.delay(dataset_id, document_ids)
"""
logger.warning("document indexing legacy mode received: %s - %s", dataset_id, document_ids)
_document_indexing(dataset_id, document_ids)
def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
"""
Process document for tasks
:param dataset_id:
:param document_ids:
Usage: _document_indexing(dataset_id, document_ids)
"""
documents = []
start_at = time.perf_counter()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow"))
db.session.close()
return
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
count = len(document_ids)
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
if 0 < vector_space.limit <= vector_space.size:
raise ValueError(
"Your total number of documents plus the number of uploads have over the limit of "
"your subscription."
)
except Exception as e:
for document_id in document_ids:
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
db.session.close()
return
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)
db.session.commit()
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
finally:
db.session.close()
def _document_indexing_with_tenant_queue(
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None]
):
try:
_document_indexing(dataset_id, document_ids)
except Exception:
logger.exception("Error processing document indexing %s for tenant %s: %s", dataset_id, tenant_id)
finally:
tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
# Check if there are waiting tasks in the queue
# Use rpop to get the next task from the queue (FIFO order)
next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
logger.info("document indexing tenant isolation queue next tasks: %s", next_tasks)
if next_tasks:
for next_task in next_tasks:
document_task = DocumentTask(**next_task)
# Process the next waiting task
# Keep the flag set to indicate a task is running
tenant_isolated_task_queue.set_task_waiting_time()
task_func.delay( # type: ignore
tenant_id=document_task.tenant_id,
dataset_id=document_task.dataset_id,
document_ids=document_task.document_ids,
)
else:
# No more waiting tasks, clear the flag
tenant_isolated_task_queue.delete_task_key()
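# Worked example of the hand-off above (values illustrative): with
# TENANT_ISOLATED_TASK_CONCURRENCY = 2 and three DocumentTasks waiting for a tenant,
# the finishing worker pulls two of them, refreshes the waiting-time marker and
# re-dispatches them via task_func.delay(...); the third stays queued until one of
# those completes, and the run that finds the queue empty clears the tenant flag.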
@shared_task(queue="dataset")
def normal_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
"""
Async process document
:param tenant_id:
:param dataset_id:
:param document_ids:
Usage: normal_document_indexing_task.delay(tenant_id, dataset_id, document_ids)
"""
logger.info("normal document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids)
_document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, normal_document_indexing_task)
@shared_task(queue="priority_dataset")
def priority_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
"""
Priority async process document
:param tenant_id:
:param dataset_id:
:param document_ids:
Usage: priority_document_indexing_task.delay(tenant_id, dataset_id, document_ids)
"""
logger.info("priority document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids)
_document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, priority_document_indexing_task)
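# Both entry points above funnel into _document_indexing_with_tenant_queue; they differ
# only in the Celery queue they consume from ("dataset" vs "priority_dataset"), so the
# priority variant is a routing choice rather than a different indexing path.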

View File

@@ -0,0 +1,81 @@
import logging
import time
import click
from celery import shared_task
from sqlalchemy import select
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def document_indexing_update_task(dataset_id: str, document_id: str):
"""
Async update document
:param dataset_id:
:param document_id:
Usage: document_indexing_update_task.delay(dataset_id, document_id)
"""
logger.info(click.style(f"Start update document: {document_id}", fg="green"))
start_at = time.perf_counter()
document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
db.session.close()
return
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
db.session.commit()
# delete all document segment and index
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
"Cleaned document when document update data source or process rule: {} latency: {}".format(
document_id, end_at - start_at
),
fg="green",
)
)
except Exception:
logger.exception("Cleaned document when document update data source or process rule failed")
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_update_task failed, document_id: %s", document_id)
finally:
db.session.close()

View File

@@ -0,0 +1,112 @@
import logging
import time
import click
from celery import shared_task
from sqlalchemy import select
from configs import dify_config
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
"""
Async process document
:param dataset_id:
:param document_ids:
Usage: duplicate_document_indexing_task.delay(dataset_id, document_ids)
"""
documents = []
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset is None:
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
db.session.close()
return
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
count = len(document_ids)
if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
current = int(getattr(vector_space, "size", 0) or 0)
limit = int(getattr(vector_space, "limit", 0) or 0)
if limit > 0 and (current + count) > limit:
raise ValueError(
"Your total number of documents plus the number of uploads have exceeded the limit of "
"your subscription."
)
except Exception as e:
for document_id in document_ids:
document = (
db.session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
return
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
# clean old data
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)
db.session.commit()
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id)
finally:
db.session.close()
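# Illustration of the vector-space guard above; the helper and values below are made up
# for clarity and are not used by the task (a limit of 0 is treated as unlimited).
def _would_exceed_quota(size: int, limit: int, count: int) -> bool:
    # Mirrors the check above: reject when the current size plus the batch exceeds the limit.
    return limit > 0 and (size + count) > limit


assert _would_exceed_quota(size=98, limit=100, count=3) is True  # 101 > 100: rejected
assert _would_exceed_quota(size=10, limit=0, count=500) is False  # unlimited plan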

View File

@@ -0,0 +1,100 @@
import logging
import time
import click
from celery import shared_task
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def enable_segment_to_index_task(segment_id: str):
"""
Async enable segment to index
:param segment_id:
Usage: enable_segment_to_index_task.delay(segment_id)
"""
logger.info(click.style(f"Start enable segment to index: {segment_id}", fg="green"))
start_at = time.perf_counter()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
db.session.close()
return
if segment.status != "completed":
logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red"))
db.session.close()
return
indexing_cache_key = f"segment_{segment.id}_indexing"
try:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
dataset = segment.dataset
if not dataset:
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
return
dataset_document = segment.document
if not dataset_document:
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
# save vector index
index_processor.load(dataset, [document])
end_at = time.perf_counter()
logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception("enable segment to index failed")
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.error = str(e)
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
db.session.close()

View File

@@ -0,0 +1,116 @@
import logging
import time
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_id: str):
"""
Async enable segments to index
:param segment_ids: list of segment ids
:param dataset_id: dataset id
:param document_id: document id
Usage: enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id)
"""
start_at = time.perf_counter()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
return
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
if not dataset_document:
logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
db.session.close()
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
db.session.close()
return
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
).all()
if not segments:
logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
db.session.close()
return
try:
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": document_id,
"dataset_id": dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": document_id,
"dataset_id": dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document)
# save vector index
index_processor.load(dataset, documents)
end_at = time.perf_counter()
logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception("enable segments to index failed")
# update segment error msg
db.session.query(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
).update(
{
"error": str(e),
"status": "error",
"disabled_at": naive_utc_now(),
"enabled": False,
}
)
db.session.commit()
finally:
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
redis_client.delete(indexing_cache_key)
db.session.close()

View File

@@ -0,0 +1,86 @@
import logging
import time
import click
from celery import shared_task
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_deletion_success_task(to: str, language: str = "en-US"):
"""
Send account deletion success email with internationalization support.
Args:
to: Recipient email address
language: Language code for email localization
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start send account deletion success email to {to}", fg="green"))
start_at = time.perf_counter()
try:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.ACCOUNT_DELETION_SUCCESS,
language_code=language,
to=to,
template_context={
"to": to,
"email": to,
},
)
end_at = time.perf_counter()
logger.info(
click.style(f"Send account deletion success email to {to}: latency: {end_at - start_at}", fg="green")
)
except Exception:
logger.exception("Send account deletion success email to %s failed", to)
@shared_task(queue="mail")
def send_account_deletion_verification_code(to: str, code: str, language: str = "en-US"):
"""
Send account deletion verification code email with internationalization support.
Args:
to: Recipient email address
code: Verification code
language: Language code for email localization
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start send account deletion verification code email to {to}", fg="green"))
start_at = time.perf_counter()
try:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.ACCOUNT_DELETION_VERIFICATION,
language_code=language,
to=to,
template_context={
"to": to,
"code": code,
},
)
end_at = time.perf_counter()
logger.info(
click.style(
"Send account deletion verification code email to {} succeeded: latency: {}".format(
to, end_at - start_at
),
fg="green",
)
)
except Exception:
logger.exception("Send account deletion verification code email to %s failed", to)

View File

@@ -0,0 +1,80 @@
import logging
import time
import click
from celery import shared_task
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_change_mail_task(language: str, to: str, code: str, phase: str):
"""
Send change email notification with internationalization support.
Args:
language: Language code for email localization
to: Recipient email address
code: Email verification code
phase: Change email phase ('old_email' or 'new_email')
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start change email mail to {to}", fg="green"))
start_at = time.perf_counter()
try:
email_service = get_email_i18n_service()
email_service.send_change_email(
language_code=language,
to=to,
code=code,
phase=phase,
)
end_at = time.perf_counter()
logger.info(click.style(f"Send change email mail to {to} succeeded: latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("Send change email mail to %s failed", to)
@shared_task(queue="mail")
def send_change_mail_completed_notification_task(language: str, to: str):
"""
Send change email completed notification with internationalization support.
Args:
language: Language code for email localization
to: Recipient email address
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start change email completed notify mail to {to}", fg="green"))
start_at = time.perf_counter()
try:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.CHANGE_EMAIL_COMPLETED,
language_code=language,
to=to,
template_context={
"to": to,
"email": to,
},
)
end_at = time.perf_counter()
logger.info(
click.style(
f"Send change email completed mail to {to} succeeded: latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("Send change email completed mail to %s failed", to)

View File

@@ -0,0 +1,46 @@
import logging
import time
import click
from celery import shared_task
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_email_code_login_mail_task(language: str, to: str, code: str):
"""
Send email code login email with internationalization support.
Args:
language: Language code for email localization
to: Recipient email address
code: Email verification code
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start email code login mail to {to}", fg="green"))
start_at = time.perf_counter()
try:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.EMAIL_CODE_LOGIN,
language_code=language,
to=to,
template_context={
"to": to,
"code": code,
},
)
end_at = time.perf_counter()
logger.info(
click.style(f"Send email code login mail to {to} succeeded: latency: {end_at - start_at}", fg="green")
)
except Exception:
logger.exception("Send email code login mail to %s failed", to)

View File

@@ -0,0 +1,61 @@
import logging
import time
from collections.abc import Mapping
from typing import Any
import click
from celery import shared_task
from flask import render_template_string
from jinja2.runtime import Context
from jinja2.sandbox import ImmutableSandboxedEnvironment
from configs import dify_config
from configs.feature import TemplateMode
from extensions.ext_mail import mail
from libs.email_i18n import get_email_i18n_service
logger = logging.getLogger(__name__)
class SandboxedEnvironment(ImmutableSandboxedEnvironment):
def __init__(self, timeout: int, *args: Any, **kwargs: Any):
self._timeout_time = time.time() + timeout
super().__init__(*args, **kwargs)
def call(self, context: Context, obj: Any, *args: Any, **kwargs: Any) -> Any:
if time.time() > self._timeout_time:
raise TimeoutError("Template rendering timeout")
return super().call(context, obj, *args, **kwargs)
def _render_template_with_strategy(body: str, substitutions: Mapping[str, str]) -> str:
mode = dify_config.MAIL_TEMPLATING_MODE
timeout = dify_config.MAIL_TEMPLATING_TIMEOUT
if mode == TemplateMode.UNSAFE:
return render_template_string(body, **substitutions)
if mode == TemplateMode.SANDBOX:
tmpl = SandboxedEnvironment(timeout=timeout).from_string(body)
return tmpl.render(substitutions)
if mode == TemplateMode.DISABLED:
return body
raise ValueError(f"Unsupported mail templating mode: {mode}")
@shared_task(queue="mail")
def send_inner_email_task(to: list[str], subject: str, body: str, substitutions: Mapping[str, str]):
if not mail.is_inited():
return
logger.info(click.style(f"Start enterprise mail to {to} with subject {subject}", fg="green"))
start_at = time.perf_counter()
try:
html_content = _render_template_with_strategy(body, substitutions)
email_service = get_email_i18n_service()
email_service.send_raw_email(to=to, subject=subject, html_content=html_content)
end_at = time.perf_counter()
logger.info(click.style(f"Send enterprise mail to {to} succeeded: latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("Send enterprise mail to %s failed", to)

View File

@@ -0,0 +1,50 @@
import logging
import time
import click
from celery import shared_task
from configs import dify_config
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_invite_member_mail_task(language: str, to: str, token: str, inviter_name: str, workspace_name: str):
"""
Send invite member email with internationalization support.
Args:
language: Language code for email localization
to: Recipient email address
token: Invitation token
inviter_name: Name of the person sending the invitation
workspace_name: Name of the workspace
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start send invite member mail to {to} in workspace {workspace_name}", fg="green"))
start_at = time.perf_counter()
try:
url = f"{dify_config.CONSOLE_WEB_URL}/activate?token={token}"
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.INVITE_MEMBER,
language_code=language,
to=to,
template_context={
"to": to,
"inviter_name": inviter_name,
"workspace_name": workspace_name,
"url": url,
},
)
end_at = time.perf_counter()
logger.info(click.style(f"Send invite member mail to {to} succeeded: latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("Send invite member mail to %s failed", to)

View File

@@ -0,0 +1,131 @@
import logging
import time
import click
from celery import shared_task
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_owner_transfer_confirm_task(language: str, to: str, code: str, workspace: str):
"""
Send owner transfer confirmation email with internationalization support.
Args:
language: Language code for email localization
to: Recipient email address
code: Verification code
workspace: Workspace name
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start owner transfer confirm mail to {to}", fg="green"))
start_at = time.perf_counter()
try:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.OWNER_TRANSFER_CONFIRM,
language_code=language,
to=to,
template_context={
"to": to,
"code": code,
"WorkspaceName": workspace,
},
)
end_at = time.perf_counter()
logger.info(
click.style(
f"Send owner transfer confirm mail to {to} succeeded: latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("owner transfer confirm email mail to %s failed", to)
@shared_task(queue="mail")
def send_old_owner_transfer_notify_email_task(language: str, to: str, workspace: str, new_owner_email: str):
"""
Send old owner transfer notification email with internationalization support.
Args:
language: Language code for email localization
to: Recipient email address
workspace: Workspace name
new_owner_email: New owner email address
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start old owner transfer notify mail to {to}", fg="green"))
start_at = time.perf_counter()
try:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.OWNER_TRANSFER_OLD_NOTIFY,
language_code=language,
to=to,
template_context={
"to": to,
"WorkspaceName": workspace,
"NewOwnerEmail": new_owner_email,
},
)
end_at = time.perf_counter()
logger.info(
click.style(
f"Send old owner transfer notify mail to {to} succeeded: latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("old owner transfer notify email mail to %s failed", to)
@shared_task(queue="mail")
def send_new_owner_transfer_notify_email_task(language: str, to: str, workspace: str):
"""
Send new owner transfer notification email with internationalization support.
Args:
language: Language code for email localization
to: Recipient email address
workspace: Workspace name
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start new owner transfer notify mail to {to}", fg="green"))
start_at = time.perf_counter()
try:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.OWNER_TRANSFER_NEW_NOTIFY,
language_code=language,
to=to,
template_context={
"to": to,
"WorkspaceName": workspace,
},
)
end_at = time.perf_counter()
logger.info(
click.style(
f"Send new owner transfer notify mail to {to} succeeded: latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("new owner transfer notify email mail to %s failed", to)

View File

@@ -0,0 +1,87 @@
import logging
import time
import click
from celery import shared_task
from configs import dify_config
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_email_register_mail_task(language: str, to: str, code: str) -> None:
"""
Send email register email with internationalization support.
Args:
language: Language code for email localization
to: Recipient email address
code: Email register code
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start email register mail to {to}", fg="green"))
start_at = time.perf_counter()
try:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.EMAIL_REGISTER,
language_code=language,
to=to,
template_context={
"to": to,
"code": code,
},
)
end_at = time.perf_counter()
logger.info(
click.style(f"Send email register mail to {to} succeeded: latency: {end_at - start_at}", fg="green")
)
except Exception:
logger.exception("Send email register mail to %s failed", to)
@shared_task(queue="mail")
def send_email_register_mail_task_when_account_exist(language: str, to: str, account_name: str) -> None:
"""
    Send email register email with internationalization support when the account already exists.
    Args:
        language: Language code for email localization
        to: Recipient email address
        account_name: Name of the existing account
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start email register mail to {to}", fg="green"))
start_at = time.perf_counter()
try:
login_url = f"{dify_config.CONSOLE_WEB_URL}/signin"
reset_password_url = f"{dify_config.CONSOLE_WEB_URL}/reset-password"
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.EMAIL_REGISTER_WHEN_ACCOUNT_EXIST,
language_code=language,
to=to,
template_context={
"to": to,
"login_url": login_url,
"reset_password_url": reset_password_url,
"account_name": account_name,
},
)
end_at = time.perf_counter()
logger.info(
click.style(f"Send email register mail to {to} succeeded: latency: {end_at - start_at}", fg="green")
)
except Exception:
logger.exception("Send email register mail to %s failed", to)

View File

@@ -0,0 +1,91 @@
import logging
import time
import click
from celery import shared_task
from configs import dify_config
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_reset_password_mail_task(language: str, to: str, code: str):
"""
Send reset password email with internationalization support.
Args:
language: Language code for email localization
to: Recipient email address
code: Reset password code
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start password reset mail to {to}", fg="green"))
start_at = time.perf_counter()
try:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.RESET_PASSWORD,
language_code=language,
to=to,
template_context={
"to": to,
"code": code,
},
)
end_at = time.perf_counter()
logger.info(
click.style(f"Send password reset mail to {to} succeeded: latency: {end_at - start_at}", fg="green")
)
except Exception:
logger.exception("Send password reset mail to %s failed", to)
@shared_task(queue="mail")
def send_reset_password_mail_task_when_account_not_exist(language: str, to: str, is_allow_register: bool) -> None:
"""
    Send reset password email with internationalization support when the account does not exist.
    Args:
        language: Language code for email localization
        to: Recipient email address
        is_allow_register: Whether self-registration is allowed, which selects the email template
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start password reset mail to {to}", fg="green"))
start_at = time.perf_counter()
try:
if is_allow_register:
sign_up_url = f"{dify_config.CONSOLE_WEB_URL}/signup"
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST,
language_code=language,
to=to,
template_context={
"to": to,
"sign_up_url": sign_up_url,
},
)
else:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER,
language_code=language,
to=to,
)
end_at = time.perf_counter()
logger.info(
click.style(f"Send password reset mail to {to} succeeded: latency: {end_at - start_at}", fg="green")
)
except Exception:
logger.exception("Send password reset mail to %s failed", to)

View File

@@ -0,0 +1,55 @@
import json
import logging
from celery import shared_task
from flask import current_app
from core.ops.entities.config_entity import OPS_FILE_PATH, OPS_TRACE_FAILED_KEY
from core.ops.entities.trace_entity import trace_info_info_map
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from models.model import Message
from models.workflow import WorkflowRun
logger = logging.getLogger(__name__)
@shared_task(queue="ops_trace")
def process_trace_tasks(file_info):
"""
Async process trace tasks
Usage: process_trace_tasks.delay(tasks_data)
"""
from core.ops.ops_trace_manager import OpsTraceManager
app_id = file_info.get("app_id")
file_id = file_info.get("file_id")
file_path = f"{OPS_FILE_PATH}{app_id}/{file_id}.json"
file_data = json.loads(storage.load(file_path))
trace_info = file_data.get("trace_info")
trace_info_type = file_data.get("trace_info_type")
trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
if trace_info.get("message_data"):
trace_info["message_data"] = Message.from_dict(data=trace_info["message_data"])
if trace_info.get("workflow_data"):
trace_info["workflow_data"] = WorkflowRun.from_dict(data=trace_info["workflow_data"])
if trace_info.get("documents"):
trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]]
try:
if trace_instance:
with current_app.app_context():
trace_type = trace_info_info_map.get(trace_info_type)
if trace_type:
trace_info = trace_type(**trace_info)
trace_instance.trace(trace_info)
logger.info("Processing trace tasks success, app_id: %s", app_id)
except Exception as e:
logger.info("error:\n\n\n%s\n\n\n\n", e)
failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
redis_client.incr(failed_key)
logger.info("Processing trace tasks failed, app_id: %s", app_id)
finally:
storage.delete(file_path)
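# Flow above: trace payloads are staged as JSON files under OPS_FILE_PATH/<app_id>/,
# rehydrated into Message / WorkflowRun / Document objects, and handed to the tenant's
# trace instance; failures only increment the per-app OPS_TRACE_FAILED_KEY counter in
# Redis, and the staged file is deleted in all cases.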

View File

@@ -0,0 +1,234 @@
import json
import operator
import typing
import click
from celery import shared_task
from core.helper import marketplace
from core.helper.marketplace import MarketplacePluginDeclaration
from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_redis import redis_client
from models.account import TenantPluginAutoUpgradeStrategy
RETRY_TIMES_OF_ONE_PLUGIN_IN_ONE_TENANT = 3
CACHE_REDIS_KEY_PREFIX = "plugin_autoupgrade_check_task:cached_plugin_manifests:"
CACHE_REDIS_TTL = 60 * 15 # 15 minutes
def _get_redis_cache_key(plugin_id: str) -> str:
"""Generate Redis cache key for plugin manifest."""
return f"{CACHE_REDIS_KEY_PREFIX}{plugin_id}"
def _get_cached_manifest(plugin_id: str) -> typing.Union[MarketplacePluginDeclaration, None, bool]:
"""
Get cached plugin manifest from Redis.
Returns:
- MarketplacePluginDeclaration: if found in cache
- None: if cached as not found (marketplace returned no result)
- False: if not in cache at all
"""
try:
key = _get_redis_cache_key(plugin_id)
cached_data = redis_client.get(key)
if cached_data is None:
return False
cached_json = json.loads(cached_data)
if cached_json is None:
return None
return MarketplacePluginDeclaration.model_validate(cached_json)
except Exception:
return False
def _set_cached_manifest(plugin_id: str, manifest: typing.Union[MarketplacePluginDeclaration, None]) -> None:
"""
Cache plugin manifest in Redis.
Args:
plugin_id: The plugin ID
manifest: The manifest to cache, or None if not found in marketplace
"""
try:
key = _get_redis_cache_key(plugin_id)
if manifest is None:
# Cache the fact that this plugin was not found
redis_client.setex(key, CACHE_REDIS_TTL, json.dumps(None))
else:
# Cache the manifest data
redis_client.setex(key, CACHE_REDIS_TTL, manifest.model_dump_json())
except Exception:
# If Redis fails, continue without caching
# traceback.print_exc()
pass
def marketplace_batch_fetch_plugin_manifests(
plugin_ids_plain_list: list[str],
) -> list[MarketplacePluginDeclaration]:
"""Fetch plugin manifests with Redis caching support."""
cached_manifests: dict[str, typing.Union[MarketplacePluginDeclaration, None]] = {}
not_cached_plugin_ids: list[str] = []
# Check Redis cache for each plugin
for plugin_id in plugin_ids_plain_list:
cached_result = _get_cached_manifest(plugin_id)
if cached_result is False:
# Not in cache, need to fetch
not_cached_plugin_ids.append(plugin_id)
else:
# Either found manifest or cached as None (not found in marketplace)
# At this point, cached_result is either MarketplacePluginDeclaration or None
if isinstance(cached_result, bool):
# This should never happen due to the if condition above, but for type safety
continue
cached_manifests[plugin_id] = cached_result
# Fetch uncached plugins from marketplace
if not_cached_plugin_ids:
manifests = marketplace.batch_fetch_plugin_manifests_ignore_deserialization_error(not_cached_plugin_ids)
# Cache the fetched manifests
for manifest in manifests:
cached_manifests[manifest.plugin_id] = manifest
_set_cached_manifest(manifest.plugin_id, manifest)
# Cache plugins that were not found in marketplace
fetched_plugin_ids = {manifest.plugin_id for manifest in manifests}
for plugin_id in not_cached_plugin_ids:
if plugin_id not in fetched_plugin_ids:
cached_manifests[plugin_id] = None
_set_cached_manifest(plugin_id, None)
# Build result list from cached manifests
result: list[MarketplacePluginDeclaration] = []
for plugin_id in plugin_ids_plain_list:
cached_manifest: typing.Union[MarketplacePluginDeclaration, None] = cached_manifests.get(plugin_id)
if cached_manifest is not None:
result.append(cached_manifest)
return result
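# Cache sentinel summary for the helpers above: False means "not in Redis at all",
# None means "cached as missing from the marketplace", and a MarketplacePluginDeclaration
# means a cached hit; only the False case triggers a marketplace fetch.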
@shared_task(queue="plugin")
def process_tenant_plugin_autoupgrade_check_task(
tenant_id: str,
strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting,
upgrade_time_of_day: int,
upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode,
exclude_plugins: list[str],
include_plugins: list[str],
):
try:
manager = PluginInstaller()
click.echo(
click.style(
f"Checking upgradable plugin for tenant: {tenant_id}",
fg="green",
)
)
if strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED:
return
# get plugin_ids to check
plugin_ids: list[tuple[str, str, str]] = [] # plugin_id, version, unique_identifier
click.echo(click.style(f"Upgrade mode: {upgrade_mode}", fg="green"))
if upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL and include_plugins:
all_plugins = manager.list_plugins(tenant_id)
for plugin in all_plugins:
if plugin.source == PluginInstallationSource.Marketplace and plugin.plugin_id in include_plugins:
plugin_ids.append(
(
plugin.plugin_id,
plugin.version,
plugin.plugin_unique_identifier,
)
)
elif upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE:
# get all plugins and remove excluded plugins
all_plugins = manager.list_plugins(tenant_id)
plugin_ids = [
(plugin.plugin_id, plugin.version, plugin.plugin_unique_identifier)
for plugin in all_plugins
if plugin.source == PluginInstallationSource.Marketplace and plugin.plugin_id not in exclude_plugins
]
elif upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL:
all_plugins = manager.list_plugins(tenant_id)
plugin_ids = [
(plugin.plugin_id, plugin.version, plugin.plugin_unique_identifier)
for plugin in all_plugins
if plugin.source == PluginInstallationSource.Marketplace
]
if not plugin_ids:
return
plugin_ids_plain_list = [plugin_id for plugin_id, _, _ in plugin_ids]
manifests = marketplace_batch_fetch_plugin_manifests(plugin_ids_plain_list)
if not manifests:
return
for manifest in manifests:
for plugin_id, version, original_unique_identifier in plugin_ids:
if manifest.plugin_id != plugin_id:
continue
try:
current_version = version
latest_version = manifest.latest_version
def fix_only_checker(latest_version: str, current_version: str):
latest_version_tuple = tuple(int(val) for val in latest_version.split("."))
current_version_tuple = tuple(int(val) for val in current_version.split("."))
if (
latest_version_tuple[0] == current_version_tuple[0]
and latest_version_tuple[1] == current_version_tuple[1]
):
return latest_version_tuple[2] != current_version_tuple[2]
return False
version_checker = {
TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: operator.ne,
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY: fix_only_checker,
}
if version_checker[strategy_setting](latest_version, current_version):
# execute upgrade
new_unique_identifier = manifest.latest_package_identifier
marketplace.record_install_plugin_event(new_unique_identifier)
click.echo(
click.style(
f"Upgrade plugin: {original_unique_identifier} -> {new_unique_identifier}",
fg="green",
)
)
_ = manager.upgrade_plugin(
tenant_id,
original_unique_identifier,
new_unique_identifier,
PluginInstallationSource.Marketplace,
{
"plugin_unique_identifier": new_unique_identifier,
},
)
except Exception as e:
click.echo(click.style(f"Error when upgrading plugin: {e}", fg="red"))
# traceback.print_exc()
break
except Exception as e:
click.echo(click.style(f"Error when checking upgradable plugin: {e}", fg="red"))
# traceback.print_exc()
return
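# Illustration of the FIX_ONLY strategy implemented by fix_only_checker above; the
# helper below is a standalone restatement for clarity and is not used by the task.
def _is_patch_only_update(latest: str, current: str) -> bool:
    latest_t = tuple(int(v) for v in latest.split("."))
    current_t = tuple(int(v) for v in current.split("."))
    # Upgrade only when major and minor match and the patch component changed.
    return latest_t[:2] == current_t[:2] and latest_t[2] != current_t[2]


assert _is_patch_only_update("1.2.5", "1.2.3") is True   # patch bump -> upgrade
assert _is_patch_only_update("1.3.0", "1.2.3") is False  # minor bump -> skip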

View File

@@ -0,0 +1,187 @@
import contextvars
import json
import logging
import time
import uuid
from collections.abc import Mapping
from concurrent.futures import ThreadPoolExecutor
from typing import Any
import click
from celery import shared_task # type: ignore
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from core.repositories.factory import DifyCoreRepositoryFactory
from extensions.ext_database import db
from models import Account, Tenant
from models.dataset import Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService
logger = logging.getLogger(__name__)
@shared_task(queue="priority_pipeline")
def priority_rag_pipeline_run_task(
rag_pipeline_invoke_entities_file_id: str,
tenant_id: str,
):
"""
Async Run rag pipeline task using high priority queue.
:param rag_pipeline_invoke_entities_file_id: File ID containing serialized RAG pipeline invoke entities
:param tenant_id: Tenant ID for the pipeline execution
"""
# run with threading, thread pool size is 10
try:
start_at = time.perf_counter()
rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(
rag_pipeline_invoke_entities_file_id
)
rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)
# Get Flask app object for thread context
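# current_app is a context-local proxy; _get_current_object() returns the real Flask app
# instance so worker threads can push their own app context with it.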
flask_app = current_app._get_current_object() # type: ignore
with ThreadPoolExecutor(max_workers=10) as executor:
futures = []
for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities:
# Submit task to thread pool with Flask app
future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity, flask_app)
futures.append(future)
# Wait for all tasks to complete
for future in futures:
try:
future.result() # This will raise any exceptions that occurred in the thread
except Exception:
logging.exception("Error in pipeline task")
end_at = time.perf_counter()
logging.info(
click.style(
f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
)
)
except Exception:
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
raise
finally:
tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "pipeline")
# Check whether more tasks are waiting in the tenant-isolated queue
# Pull the next batch of waiting tasks (FIFO order), up to TENANT_ISOLATED_TASK_CONCURRENCY
next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
logger.info("priority rag pipeline tenant isolation queue next files: %s", next_file_ids)
if next_file_ids:
for next_file_id in next_file_ids:
# Process the next waiting task
# Keep the flag set to indicate a task is running
tenant_isolated_task_queue.set_task_waiting_time()
priority_rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
if isinstance(next_file_id, bytes)
else next_file_id,
tenant_id=tenant_id,
)
else:
# No more waiting tasks, clear the flag
tenant_isolated_task_queue.delete_task_key()
file_service = FileService(db.engine)
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
db.session.close()
def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app):
"""Run a single RAG pipeline task within Flask app context."""
# Create Flask application context for this thread
with flask_app.app_context():
try:
rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity.model_validate(rag_pipeline_invoke_entity)
user_id = rag_pipeline_invoke_entity_model.user_id
tenant_id = rag_pipeline_invoke_entity_model.tenant_id
pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
workflow_id = rag_pipeline_invoke_entity_model.workflow_id
streaming = rag_pipeline_invoke_entity_model.streaming
workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity
with Session(db.engine, expire_on_commit=False) as session:
# Load required entities
account = session.query(Account).where(Account.id == user_id).first()
if not account:
raise ValueError(f"Account {user_id} not found")
tenant = session.query(Tenant).where(Tenant.id == tenant_id).first()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
account.current_tenant = tenant
pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
if not pipeline:
raise ValueError(f"Pipeline {pipeline_id} not found")
workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError(f"Workflow {pipeline.workflow_id} not found")
if workflow_execution_id is None:
workflow_execution_id = str(uuid.uuid4())
# Create application generate entity from dict
entity = RagPipelineGenerateEntity.model_validate(application_generate_entity)
# Create workflow repositories
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=account,
app_id=entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
)
workflow_node_execution_repository = (
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=account,
app_id=entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
)
# Set the user directly in g for preserve_flask_contexts
g._login_user = account
# Copy context for passing to pipeline generator
context = contextvars.copy_context()
# Direct execution without creating another thread
# Since we're already in a thread pool, no need for nested threading
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
pipeline_generator = PipelineGenerator()
# Using protected method intentionally for async execution
pipeline_generator._generate( # type: ignore[attr-defined]
flask_app=flask_app,
context=context,
pipeline=pipeline,
workflow_id=workflow_id,
user=account,
application_generate_entity=entity,
invoke_from=InvokeFrom.PUBLISHED,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
except Exception:
logging.exception("Error in priority pipeline task")
raise

View File

@@ -0,0 +1,187 @@
import contextvars
import json
import logging
import time
import uuid
from collections.abc import Mapping
from concurrent.futures import ThreadPoolExecutor
from typing import Any
import click
from celery import shared_task # type: ignore
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from core.repositories.factory import DifyCoreRepositoryFactory
from extensions.ext_database import db
from models import Account, Tenant
from models.dataset import Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService
logger = logging.getLogger(__name__)
@shared_task(queue="pipeline")
def rag_pipeline_run_task(
rag_pipeline_invoke_entities_file_id: str,
tenant_id: str,
):
"""
Async Run RAG pipeline task using the regular-priority queue.
:param rag_pipeline_invoke_entities_file_id: File ID containing serialized RAG pipeline invoke entities
:param tenant_id: Tenant ID for the pipeline execution
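Usage: rag_pipeline_run_task.delay(rag_pipeline_invoke_entities_file_id=file_id, tenant_id=tenant_id)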
"""
# Run the invoke entities concurrently using a thread pool of size 10
try:
start_at = time.perf_counter()
rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(
rag_pipeline_invoke_entities_file_id
)
rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)
# Get Flask app object for thread context
flask_app = current_app._get_current_object() # type: ignore
with ThreadPoolExecutor(max_workers=10) as executor:
futures = []
for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities:
# Submit task to thread pool with Flask app
future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity, flask_app)
futures.append(future)
# Wait for all tasks to complete
for future in futures:
try:
future.result() # This will raise any exceptions that occurred in the thread
except Exception:
logging.exception("Error in pipeline task")
end_at = time.perf_counter()
logging.info(
click.style(
f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
)
)
except Exception:
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
raise
finally:
tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "pipeline")
# Check whether more tasks are waiting in the tenant-isolated queue
# Pull the next batch of waiting tasks (FIFO order), up to TENANT_ISOLATED_TASK_CONCURRENCY
next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
logger.info("rag pipeline tenant isolation queue next files: %s", next_file_ids)
if next_file_ids:
for next_file_id in next_file_ids:
# Process the next waiting task
# Keep the flag set to indicate a task is running
tenant_isolated_task_queue.set_task_waiting_time()
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
if isinstance(next_file_id, bytes)
else next_file_id,
tenant_id=tenant_id,
)
else:
# No more waiting tasks, clear the flag
tenant_isolated_task_queue.delete_task_key()
file_service = FileService(db.engine)
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
db.session.close()
def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app):
"""Run a single RAG pipeline task within Flask app context."""
# Create Flask application context for this thread
with flask_app.app_context():
try:
rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity.model_validate(rag_pipeline_invoke_entity)
user_id = rag_pipeline_invoke_entity_model.user_id
tenant_id = rag_pipeline_invoke_entity_model.tenant_id
pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
workflow_id = rag_pipeline_invoke_entity_model.workflow_id
streaming = rag_pipeline_invoke_entity_model.streaming
workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity
with Session(db.engine) as session:
# Load required entities
account = session.query(Account).where(Account.id == user_id).first()
if not account:
raise ValueError(f"Account {user_id} not found")
tenant = session.query(Tenant).where(Tenant.id == tenant_id).first()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
account.current_tenant = tenant
pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
if not pipeline:
raise ValueError(f"Pipeline {pipeline_id} not found")
workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError(f"Workflow {pipeline.workflow_id} not found")
if workflow_execution_id is None:
workflow_execution_id = str(uuid.uuid4())
# Create application generate entity from dict
entity = RagPipelineGenerateEntity.model_validate(application_generate_entity)
# Create workflow repositories
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=account,
app_id=entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
)
workflow_node_execution_repository = (
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=account,
app_id=entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
)
# Set the user directly in g for preserve_flask_contexts
g._login_user = account
# Copy context for passing to pipeline generator
context = contextvars.copy_context()
# Direct execution without creating another thread
# Since we're already in a thread pool, no need for nested threading
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
pipeline_generator = PipelineGenerator()
# Using protected method intentionally for async execution
pipeline_generator._generate( # type: ignore[attr-defined]
flask_app=flask_app,
context=context,
pipeline=pipeline,
workflow_id=workflow_id,
user=account,
application_generate_entity=entity,
invoke_from=InvokeFrom.PUBLISHED,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
except Exception:
logging.exception("Error in pipeline task")
raise

View File

@@ -0,0 +1,48 @@
import logging
import time
import click
from celery import shared_task
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from extensions.ext_database import db
from models.dataset import Document
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def recover_document_indexing_task(dataset_id: str, document_id: str):
"""
Async recover document
:param dataset_id:
:param document_id:
Usage: recover_document_indexing_task.delay(dataset_id, document_id)
"""
logger.info(click.style(f"Recover document: {document_id}", fg="green"))
start_at = time.perf_counter()
document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
db.session.close()
return
try:
indexing_runner = IndexingRunner()
if document.indexing_status in {"waiting", "parsing", "cleaning"}:
indexing_runner.run([document])
elif document.indexing_status == "splitting":
indexing_runner.run_in_splitting_status(document)
elif document.indexing_status == "indexing":
indexing_runner.run_in_indexing_status(document)
end_at = time.perf_counter()
logger.info(click.style(f"Processed document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("recover_document_indexing_task failed, document_id: %s", document_id)
finally:
db.session.close()

View File

@@ -0,0 +1,577 @@
import logging
import time
from collections.abc import Callable
import click
import sqlalchemy as sa
from celery import shared_task
from sqlalchemy import delete
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
from models import (
ApiToken,
AppAnnotationHitHistory,
AppAnnotationSetting,
AppDatasetJoin,
AppMCPServer,
AppModelConfig,
AppTrigger,
Conversation,
EndUser,
InstalledApp,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
RecommendedApp,
Site,
TagBinding,
TraceAppConfig,
WorkflowSchedulePlan,
)
from models.tools import WorkflowToolProvider
from models.trigger import WorkflowPluginTrigger, WorkflowTriggerLog, WorkflowWebhookTrigger
from models.web import PinnedConversation, SavedMessage
from models.workflow import (
ConversationVariable,
Workflow,
WorkflowAppLog,
)
from repositories.factory import DifyAPIRepositoryFactory
logger = logging.getLogger(__name__)
@shared_task(queue="app_deletion", bind=True, max_retries=3)
def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
logger.info(click.style(f"Start deleting app and related data: {tenant_id}:{app_id}", fg="green"))
start_at = time.perf_counter()
try:
# Delete related data
_delete_app_model_configs(tenant_id, app_id)
_delete_app_site(tenant_id, app_id)
_delete_app_mcp_servers(tenant_id, app_id)
_delete_app_api_tokens(tenant_id, app_id)
_delete_installed_apps(tenant_id, app_id)
_delete_recommended_apps(tenant_id, app_id)
_delete_app_annotation_data(tenant_id, app_id)
_delete_app_dataset_joins(tenant_id, app_id)
_delete_app_workflows(tenant_id, app_id)
_delete_app_workflow_runs(tenant_id, app_id)
_delete_app_workflow_node_executions(tenant_id, app_id)
_delete_app_workflow_app_logs(tenant_id, app_id)
_delete_app_conversations(tenant_id, app_id)
_delete_app_messages(tenant_id, app_id)
_delete_workflow_tool_providers(tenant_id, app_id)
_delete_app_tag_bindings(tenant_id, app_id)
_delete_end_users(tenant_id, app_id)
_delete_trace_app_configs(tenant_id, app_id)
_delete_conversation_variables(app_id=app_id)
_delete_draft_variables(app_id)
_delete_app_triggers(tenant_id, app_id)
_delete_workflow_plugin_triggers(tenant_id, app_id)
_delete_workflow_webhook_triggers(tenant_id, app_id)
_delete_workflow_schedule_plans(tenant_id, app_id)
_delete_workflow_trigger_logs(tenant_id, app_id)
end_at = time.perf_counter()
logger.info(click.style(f"App and related data deleted: {app_id} latency: {end_at - start_at}", fg="green"))
except SQLAlchemyError as e:
logger.exception(click.style(f"Database error occurred while deleting app {app_id} and related data", fg="red"))
raise self.retry(exc=e, countdown=60) # Retry after 60 seconds
except Exception as e:
logger.exception(click.style(f"Error occurred while deleting app {app_id} and related data", fg="red"))
raise self.retry(exc=e, countdown=60) # Retry after 60 seconds
def _delete_app_model_configs(tenant_id: str, app_id: str):
def del_model_config(model_config_id: str):
db.session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_model_configs where app_id=:app_id limit 1000""",
{"app_id": app_id},
del_model_config,
"app model config",
)
def _delete_app_site(tenant_id: str, app_id: str):
def del_site(site_id: str):
db.session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
_delete_records(
"""select id from sites where app_id=:app_id limit 1000""",
{"app_id": app_id},
del_site,
"site",
)
def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def del_mcp_server(mcp_server_id: str):
db.session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_mcp_servers where app_id=:app_id limit 1000""",
{"app_id": app_id},
del_mcp_server,
"app mcp server",
)
def _delete_app_api_tokens(tenant_id: str, app_id: str):
def del_api_token(api_token_id: str):
db.session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
_delete_records(
"""select id from api_tokens where app_id=:app_id limit 1000""",
{"app_id": app_id},
del_api_token,
"api token",
)
def _delete_installed_apps(tenant_id: str, app_id: str):
def del_installed_app(installed_app_id: str):
db.session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
_delete_records(
"""select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
{"tenant_id": tenant_id, "app_id": app_id},
del_installed_app,
"installed app",
)
def _delete_recommended_apps(tenant_id: str, app_id: str):
def del_recommended_app(recommended_app_id: str):
db.session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(
synchronize_session=False
)
_delete_records(
"""select id from recommended_apps where app_id=:app_id limit 1000""",
{"app_id": app_id},
del_recommended_app,
"recommended app",
)
def _delete_app_annotation_data(tenant_id: str, app_id: str):
def del_annotation_hit_history(annotation_hit_history_id: str):
db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
synchronize_session=False
)
_delete_records(
"""select id from app_annotation_hit_histories where app_id=:app_id limit 1000""",
{"app_id": app_id},
del_annotation_hit_history,
"annotation hit history",
)
def del_annotation_setting(annotation_setting_id: str):
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
synchronize_session=False
)
_delete_records(
"""select id from app_annotation_settings where app_id=:app_id limit 1000""",
{"app_id": app_id},
del_annotation_setting,
"annotation setting",
)
def _delete_app_dataset_joins(tenant_id: str, app_id: str):
def del_dataset_join(dataset_join_id: str):
db.session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_dataset_joins where app_id=:app_id limit 1000""",
{"app_id": app_id},
del_dataset_join,
"dataset join",
)
def _delete_app_workflows(tenant_id: str, app_id: str):
def del_workflow(workflow_id: str):
db.session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
{"tenant_id": tenant_id, "app_id": app_id},
del_workflow,
"workflow",
)
def _delete_app_workflow_runs(tenant_id: str, app_id: str):
"""Delete all workflow runs for an app using the service repository."""
session_maker = sessionmaker(bind=db.engine)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
deleted_count = workflow_run_repo.delete_runs_by_app(
tenant_id=tenant_id,
app_id=app_id,
batch_size=1000,
)
logger.info("Deleted %s workflow runs for app %s", deleted_count, app_id)
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
"""Delete all workflow node executions for an app using the service repository."""
session_maker = sessionmaker(bind=db.engine)
node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
deleted_count = node_execution_repo.delete_executions_by_app(
tenant_id=tenant_id,
app_id=app_id,
batch_size=1000,
)
logger.info("Deleted %s workflow node executions for app %s", deleted_count, app_id)
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def del_workflow_app_log(workflow_app_log_id: str):
db.session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(
synchronize_session=False
)
_delete_records(
"""select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
{"tenant_id": tenant_id, "app_id": app_id},
del_workflow_app_log,
"workflow app log",
)
def _delete_app_conversations(tenant_id: str, app_id: str):
def del_conversation(conversation_id: str):
db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
synchronize_session=False
)
db.session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)
_delete_records(
"""select id from conversations where app_id=:app_id limit 1000""",
{"app_id": app_id},
del_conversation,
"conversation",
)
def _delete_conversation_variables(*, app_id: str):
stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
with db.engine.connect() as conn:
conn.execute(stmt)
conn.commit()
logger.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green"))
def _delete_app_messages(tenant_id: str, app_id: str):
def del_message(message_id: str):
db.session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(
synchronize_session=False
)
db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
synchronize_session=False
)
db.session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False)
db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete(
synchronize_session=False
)
db.session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False)
db.session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False)
db.session.query(Message).where(Message.id == message_id).delete()
_delete_records(
"""select id from messages where app_id=:app_id limit 1000""",
{"app_id": app_id},
del_message,
"message",
)
def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
def del_tool_provider(tool_provider_id: str):
db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
synchronize_session=False
)
_delete_records(
"""select id from tool_workflow_providers where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
{"tenant_id": tenant_id, "app_id": app_id},
del_tool_provider,
"tool workflow provider",
)
def _delete_app_tag_bindings(tenant_id: str, app_id: str):
def del_tag_binding(tag_binding_id: str):
db.session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
_delete_records(
"""select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""",
{"tenant_id": tenant_id, "app_id": app_id},
del_tag_binding,
"tag binding",
)
def _delete_end_users(tenant_id: str, app_id: str):
def del_end_user(end_user_id: str):
db.session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
_delete_records(
"""select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
{"tenant_id": tenant_id, "app_id": app_id},
del_end_user,
"end user",
)
def _delete_trace_app_configs(tenant_id: str, app_id: str):
def del_trace_app_config(trace_app_config_id: str):
db.session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(
synchronize_session=False
)
_delete_records(
"""select id from trace_app_config where app_id=:app_id limit 1000""",
{"app_id": app_id},
del_trace_app_config,
"trace app config",
)
def _delete_draft_variables(app_id: str):
"""Delete all workflow draft variables for an app in batches."""
return delete_draft_variables_batch(app_id, batch_size=1000)
def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
"""
Delete draft variables for an app in batches.
This function now handles cleanup of associated Offload data including:
- WorkflowDraftVariableFile records
- UploadFile records
- Object storage files
Args:
app_id: The ID of the app whose draft variables should be deleted
batch_size: Number of records to delete per batch
Returns:
Total number of records deleted
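Example (illustrative): delete_draft_variables_batch(app_id, batch_size=500)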
"""
if batch_size <= 0:
raise ValueError("batch_size must be positive")
total_deleted = 0
total_files_deleted = 0
while True:
with db.engine.begin() as conn:
# Get a batch of draft variable IDs along with their file_ids
query_sql = """
SELECT id, file_id FROM workflow_draft_variables
WHERE app_id = :app_id
LIMIT :batch_size
"""
result = conn.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size})
rows = list(result)
if not rows:
break
draft_var_ids = [row[0] for row in rows]
file_ids = [row[1] for row in rows if row[1] is not None]
# Clean up associated Offload data first
if file_ids:
files_deleted = _delete_draft_variable_offload_data(conn, file_ids)
total_files_deleted += files_deleted
# Delete the draft variables
delete_sql = """
DELETE FROM workflow_draft_variables
WHERE id IN :ids
"""
deleted_result = conn.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)})
batch_deleted = deleted_result.rowcount
total_deleted += batch_deleted
logger.info(click.style(f"Deleted {batch_deleted} draft variables (batch) for app {app_id}", fg="green"))
logger.info(
click.style(
f"Deleted {total_deleted} total draft variables for app {app_id}. "
f"Cleaned up {total_files_deleted} total associated files.",
fg="green",
)
)
return total_deleted
def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
"""
Delete Offload data associated with WorkflowDraftVariable file_ids.
This function:
1. Finds WorkflowDraftVariableFile records by file_ids
2. Deletes associated files from object storage
3. Deletes UploadFile records
4. Deletes WorkflowDraftVariableFile records
Args:
conn: Database connection
file_ids: List of WorkflowDraftVariableFile IDs
Returns:
Number of files cleaned up
"""
from extensions.ext_storage import storage
if not file_ids:
return 0
files_deleted = 0
try:
# Get WorkflowDraftVariableFile records and their associated UploadFile keys
query_sql = """
SELECT wdvf.id, uf.key, uf.id as upload_file_id
FROM workflow_draft_variable_files wdvf
JOIN upload_files uf ON wdvf.upload_file_id = uf.id
WHERE wdvf.id IN :file_ids
"""
result = conn.execute(sa.text(query_sql), {"file_ids": tuple(file_ids)})
file_records = list(result)
# Delete from object storage and collect upload file IDs
upload_file_ids = []
for _, storage_key, upload_file_id in file_records:
try:
storage.delete(storage_key)
upload_file_ids.append(upload_file_id)
files_deleted += 1
except Exception:
logging.exception("Failed to delete storage object %s", storage_key)
# Continue with database cleanup even if storage deletion fails
upload_file_ids.append(upload_file_id)
# Delete UploadFile records
if upload_file_ids:
delete_upload_files_sql = """
DELETE FROM upload_files
WHERE id IN :upload_file_ids
"""
conn.execute(sa.text(delete_upload_files_sql), {"upload_file_ids": tuple(upload_file_ids)})
# Delete WorkflowDraftVariableFile records
delete_variable_files_sql = """
DELETE FROM workflow_draft_variable_files
WHERE id IN :file_ids
"""
conn.execute(sa.text(delete_variable_files_sql), {"file_ids": tuple(file_ids)})
except Exception:
logging.exception("Error deleting draft variable offload data:")
# Don't raise, as we want to continue with the main deletion process
return files_deleted
def _delete_app_triggers(tenant_id: str, app_id: str):
def del_app_trigger(trigger_id: str):
db.session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_triggers where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
{"tenant_id": tenant_id, "app_id": app_id},
del_app_trigger,
"app trigger",
)
def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
def del_plugin_trigger(trigger_id: str):
db.session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
synchronize_session=False
)
_delete_records(
"""select id from workflow_plugin_triggers where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
{"tenant_id": tenant_id, "app_id": app_id},
del_plugin_trigger,
"workflow plugin trigger",
)
def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
def del_webhook_trigger(trigger_id: str):
db.session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
synchronize_session=False
)
_delete_records(
"""select id from workflow_webhook_triggers where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
{"tenant_id": tenant_id, "app_id": app_id},
del_webhook_trigger,
"workflow webhook trigger",
)
def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
def del_schedule_plan(plan_id: str):
db.session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(
synchronize_session=False
)
_delete_records(
"""select id from workflow_schedule_plans where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
{"tenant_id": tenant_id, "app_id": app_id},
del_schedule_plan,
"workflow schedule plan",
)
def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
def del_trigger_log(log_id: str):
db.session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflow_trigger_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
{"tenant_id": tenant_id, "app_id": app_id},
del_trigger_log,
"workflow trigger log",
)
def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
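# Generic batched-delete helper: each caller's query selects up to 1000 matching ids per pass;
# delete_func is invoked and committed per record so one failure does not abort the batch,
# and the loop exits once a pass returns no rows.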
while True:
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query_sql), params)
if rs.rowcount == 0:
break
for i in rs:
record_id = str(i.id)
try:
delete_func(record_id)
db.session.commit()
logger.info(click.style(f"Deleted {name} {record_id}", fg="green"))
except Exception:
logger.exception("Error occurred while deleting %s %s", name, record_id)
continue
rs.close()

View File

@@ -0,0 +1,76 @@
import logging
import time
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Document, DocumentSegment
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def remove_document_from_index_task(document_id: str):
"""
Async Remove document from index
:param document_id: document id
Usage: remove_document_from_index_task.delay(document_id)
"""
logger.info(click.style(f"Start remove document segments from index: {document_id}", fg="green"))
start_at = time.perf_counter()
document = db.session.query(Document).where(Document.id == document_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
db.session.close()
return
if document.indexing_status != "completed":
logger.info(click.style(f"Document is not completed, remove is not allowed: {document_id}", fg="red"))
db.session.close()
return
indexing_cache_key = f"document_{document.id}_indexing"
try:
dataset = document.dataset
if not dataset:
raise Exception("Document has no dataset")
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()
index_node_ids = [segment.index_node_id for segment in segments]
if index_node_ids:
try:
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
except Exception:
logger.exception("clean dataset %s from index failed", dataset.id)
# update segment to disable
db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update(
{
DocumentSegment.enabled: False,
DocumentSegment.disabled_at: naive_utc_now(),
DocumentSegment.disabled_by: document.disabled_by,
DocumentSegment.updated_at: naive_utc_now(),
}
)
db.session.commit()
end_at = time.perf_counter()
logger.info(click.style(f"Document removed from index: {document.id} latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("remove document from index failed")
if not document.archived:
document.enabled = True
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
db.session.close()

View File

@@ -0,0 +1,125 @@
import logging
import time
import click
from celery import shared_task
from sqlalchemy import select
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models import Account, Tenant
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
from services.rag_pipeline.rag_pipeline import RagPipelineService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_id: str):
"""
Async retry document indexing
:param dataset_id:
:param document_ids:
:param user_id:
Usage: retry_document_indexing_task.delay(dataset_id, document_ids, user_id)
"""
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
return
user = db.session.query(Account).where(Account.id == user_id).first()
if not user:
logger.info(click.style(f"User not found: {user_id}", fg="red"))
return
tenant = db.session.query(Tenant).where(Tenant.id == dataset.tenant_id).first()
if not tenant:
raise ValueError("Tenant not found")
user.current_tenant = tenant
for document_id in document_ids:
retry_indexing_cache_key = f"document_{document_id}_is_retried"
# check document limit
features = FeatureService.get_features(tenant.id)
try:
if features.billing.enabled:
vector_space = features.vector_space
if 0 < vector_space.limit <= vector_space.size:
raise ValueError(
"Your total number of documents plus the number of uploads have over the limit of "
"your subscription."
)
except Exception as e:
document = (
db.session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
redis_client.delete(retry_indexing_cache_key)
return
logger.info(click.style(f"Start retry document: {document_id}", fg="green"))
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
return
try:
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
db.session.add(document)
db.session.commit()
if dataset.runtime_mode == "rag_pipeline":
rag_pipeline_service = RagPipelineService()
rag_pipeline_service.retry_error_document(dataset, document, user)
else:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
redis_client.delete(retry_indexing_cache_key)
except Exception as ex:
document.indexing_status = "error"
document.error = str(ex)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
logger.info(click.style(str(ex), fg="yellow"))
redis_client.delete(retry_indexing_cache_key)
logger.exception("retry_document_indexing_task failed, document_id: %s", document_id)
end_at = time.perf_counter()
logger.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception(
"retry_document_indexing_task failed, dataset_id: %s, document_ids: %s", dataset_id, document_ids
)
raise e
finally:
db.session.close()

View File

@@ -0,0 +1,95 @@
import logging
import time
import click
from celery import shared_task
from sqlalchemy import select
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def sync_website_document_indexing_task(dataset_id: str, document_id: str):
"""
Async sync website document indexing
:param dataset_id:
:param document_id:
Usage: sync_website_document_indexing_task.delay(dataset_id, document_id)
"""
start_at = time.perf_counter()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset is None:
raise ValueError("Dataset not found")
sync_indexing_cache_key = f"document_{document_id}_is_sync"
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
if 0 < vector_space.limit <= vector_space.size:
raise ValueError(
"Your total number of documents plus the number of uploads have over the limit of "
"your subscription."
)
except Exception as e:
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
redis_client.delete(sync_indexing_cache_key)
return
logger.info(click.style(f"Start sync website document: {document_id}", fg="green"))
document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
return
try:
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
db.session.add(document)
db.session.commit()
indexing_runner = IndexingRunner()
indexing_runner.run([document])
redis_client.delete(sync_indexing_cache_key)
except Exception as ex:
document.indexing_status = "error"
document.error = str(ex)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
logger.info(click.style(str(ex), fg="yellow"))
redis_client.delete(sync_indexing_cache_key)
logger.exception("sync_website_document_indexing_task failed, document_id: %s", document_id)
end_at = time.perf_counter()
logger.info(click.style(f"Sync document: {document_id} latency: {end_at - start_at}", fg="green"))

View File

@@ -0,0 +1,521 @@
"""
Celery tasks for async trigger processing.
These tasks handle trigger workflow execution asynchronously
to avoid blocking the main request thread.
"""
import json
import logging
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Any
from celery import shared_task
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import TriggerInvokeEventResponse
from core.plugin.impl.exc import PluginInvokeError
from core.trigger.debug.event_bus import TriggerDebugEventBus
from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key
from core.trigger.entities.entities import TriggerProviderEntity
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.workflow.enums import NodeType, WorkflowExecutionStatus
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from enums.quota_type import QuotaType, unlimited
from extensions.ext_database import db
from models.enums import (
AppTriggerType,
CreatorUserRole,
WorkflowRunTriggeredFrom,
WorkflowTriggerStatus,
)
from models.model import EndUser
from models.provider_ids import TriggerProviderID
from models.trigger import TriggerSubscription, WorkflowPluginTrigger, WorkflowTriggerLog
from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowRun
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.errors.app import QuotaExceededError
from services.trigger.app_trigger_service import AppTriggerService
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_request_service import TriggerHttpRequestCachingService
from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService
from services.workflow.entities import PluginTriggerData, PluginTriggerDispatchData, PluginTriggerMetadata
from services.workflow.queue_dispatcher import QueueDispatcherManager
logger = logging.getLogger(__name__)
# Dedicated Celery queue for dispatching triggered workflows
TRIGGER_QUEUE = "triggered_workflow_dispatcher"
def dispatch_trigger_debug_event(
events: list[str],
user_id: str,
timestamp: int,
request_id: str,
subscription: TriggerSubscription,
) -> int:
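# Fans out one debug event per event name to any live debug sessions listening on the pool key
# (event name + tenant + subscription + provider) and returns how many sessions received it;
# failures are swallowed so debug dispatch never blocks normal trigger processing.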
debug_dispatched = 0
try:
for event_name in events:
pool_key: str = build_plugin_pool_key(
name=event_name,
tenant_id=subscription.tenant_id,
subscription_id=subscription.id,
provider_id=subscription.provider_id,
)
trigger_debug_event: PluginTriggerDebugEvent = PluginTriggerDebugEvent(
timestamp=timestamp,
user_id=user_id,
name=event_name,
request_id=request_id,
subscription_id=subscription.id,
provider_id=subscription.provider_id,
)
debug_dispatched += TriggerDebugEventBus.dispatch(
tenant_id=subscription.tenant_id,
event=trigger_debug_event,
pool_key=pool_key,
)
logger.debug(
"Trigger debug dispatched %d sessions to pool %s for event %s for subscription %s provider %s",
debug_dispatched,
pool_key,
event_name,
subscription.id,
subscription.provider_id,
)
return debug_dispatched
except Exception:
logger.exception("Failed to dispatch to debug sessions")
return 0
def _get_latest_workflows_by_app_ids(
session: Session, subscribers: Sequence[WorkflowPluginTrigger]
) -> Mapping[str, Workflow]:
"""Get the latest workflows by app_ids"""
workflow_query = (
select(Workflow.app_id, func.max(Workflow.created_at).label("max_created_at"))
.where(
Workflow.app_id.in_({t.app_id for t in subscribers}),
Workflow.version != Workflow.VERSION_DRAFT,
)
.group_by(Workflow.app_id)
.subquery()
)
workflows = session.scalars(
select(Workflow).join(
workflow_query,
(Workflow.app_id == workflow_query.c.app_id) & (Workflow.created_at == workflow_query.c.max_created_at),
)
).all()
return {w.app_id: w for w in workflows}
def _record_trigger_failure_log(
*,
session: Session,
workflow: Workflow,
plugin_trigger: WorkflowPluginTrigger,
subscription: TriggerSubscription,
trigger_metadata: PluginTriggerMetadata,
end_user: EndUser | None,
error_message: str,
event_name: str,
request_id: str,
) -> None:
"""
Persist a workflow run, workflow app log, and trigger log entry for failed trigger invocations.
"""
now = datetime.now(UTC)
if end_user:
created_by_role = CreatorUserRole.END_USER
created_by = end_user.id
else:
created_by_role = CreatorUserRole.ACCOUNT
created_by = subscription.user_id
failure_inputs = {
"event_name": event_name,
"subscription_id": subscription.id,
"request_id": request_id,
"plugin_trigger_id": plugin_trigger.id,
}
workflow_run = WorkflowRun(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
type=workflow.type,
triggered_from=WorkflowRunTriggeredFrom.PLUGIN.value,
version=workflow.version,
graph=workflow.graph,
inputs=json.dumps(failure_inputs),
status=WorkflowExecutionStatus.FAILED.value,
outputs="{}",
error=error_message,
elapsed_time=0.0,
total_tokens=0,
total_steps=0,
created_by_role=created_by_role.value,
created_by=created_by,
created_at=now,
finished_at=now,
exceptions_count=0,
)
session.add(workflow_run)
session.flush()
workflow_app_log = WorkflowAppLog(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from=WorkflowAppLogCreatedFrom.SERVICE_API.value,
created_by_role=created_by_role.value,
created_by=created_by,
)
session.add(workflow_app_log)
dispatcher = QueueDispatcherManager.get_dispatcher(subscription.tenant_id)
queue_name = dispatcher.get_queue_name()
trigger_data = PluginTriggerData(
app_id=plugin_trigger.app_id,
tenant_id=subscription.tenant_id,
workflow_id=workflow.id,
root_node_id=plugin_trigger.node_id,
inputs={},
trigger_metadata=trigger_metadata,
plugin_id=subscription.provider_id,
endpoint_id=subscription.endpoint_id,
)
trigger_log = WorkflowTriggerLog(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
root_node_id=plugin_trigger.node_id,
trigger_metadata=trigger_metadata.model_dump_json(),
trigger_type=AppTriggerType.TRIGGER_PLUGIN,
trigger_data=trigger_data.model_dump_json(),
inputs=json.dumps({}),
status=WorkflowTriggerStatus.FAILED,
error=error_message,
queue_name=queue_name,
retry_count=0,
created_by_role=created_by_role.value,
created_by=created_by,
triggered_at=now,
finished_at=now,
elapsed_time=0.0,
total_tokens=0,
outputs=None,
celery_task_id=None,
)
session.add(trigger_log)
session.commit()
def dispatch_triggered_workflow(
user_id: str,
subscription: TriggerSubscription,
event_name: str,
request_id: str,
) -> int:
"""Process triggered workflows.
Args:
subscription: The trigger subscription
event: The trigger entity that was activated
request_id: The ID of the stored request in storage system
"""
request = TriggerHttpRequestCachingService.get_request(request_id)
payload = TriggerHttpRequestCachingService.get_payload(request_id)
subscribers: list[WorkflowPluginTrigger] = TriggerSubscriptionOperatorService.get_subscriber_triggers(
tenant_id=subscription.tenant_id, subscription_id=subscription.id, event_name=event_name
)
if not subscribers:
logger.warning(
"No workflows found for trigger event '%s' in subscription '%s'",
event_name,
subscription.id,
)
return 0
dispatched_count = 0
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id)
)
trigger_entity: TriggerProviderEntity = provider_controller.entity
with Session(db.engine) as session:
workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers)
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
type=InvokeFrom.TRIGGER,
tenant_id=subscription.tenant_id,
app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers],
user_id=user_id,
)
for plugin_trigger in subscribers:
# Get workflow from mapping
workflow: Workflow | None = workflows.get(plugin_trigger.app_id)
if not workflow:
logger.error(
"Workflow not found for app %s",
plugin_trigger.app_id,
)
continue
# Find the trigger node in the workflow
event_node = None
for node_id, node_config in workflow.walk_nodes(NodeType.TRIGGER_PLUGIN):
if node_id == plugin_trigger.node_id:
event_node = node_config
break
if not event_node:
logger.error("Trigger event node not found for app %s", plugin_trigger.app_id)
continue
# invoke trigger
trigger_metadata = PluginTriggerMetadata(
plugin_unique_identifier=provider_controller.plugin_unique_identifier or "",
endpoint_id=subscription.endpoint_id,
provider_id=subscription.provider_id,
event_name=event_name,
icon_filename=trigger_entity.identity.icon or "",
icon_dark_filename=trigger_entity.identity.icon_dark or "",
)
# consume quota before invoking trigger
quota_charge = unlimited()
try:
quota_charge = QuotaType.TRIGGER.consume(subscription.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
logger.info(
"Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id
)
return 0
node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node)
invoke_response: TriggerInvokeEventResponse | None = None
try:
invoke_response = TriggerManager.invoke_trigger_event(
tenant_id=subscription.tenant_id,
user_id=user_id,
provider_id=TriggerProviderID(subscription.provider_id),
event_name=event_name,
parameters=node_data.resolve_parameters(
parameter_schemas=provider_controller.get_event_parameters(event_name=event_name)
),
credentials=subscription.credentials,
credential_type=CredentialType.of(subscription.credential_type),
subscription=subscription.to_entity(),
request=request,
payload=payload,
)
except PluginInvokeError as e:
quota_charge.refund()
error_message = e.to_user_friendly_error(plugin_name=trigger_entity.identity.name)
try:
end_user = end_users.get(plugin_trigger.app_id)
_record_trigger_failure_log(
session=session,
workflow=workflow,
plugin_trigger=plugin_trigger,
subscription=subscription,
trigger_metadata=trigger_metadata,
end_user=end_user,
error_message=error_message,
event_name=event_name,
request_id=request_id,
)
except Exception:
logger.exception(
"Failed to record trigger failure log for app %s",
plugin_trigger.app_id,
)
continue
except Exception:
quota_charge.refund()
logger.exception(
"Failed to invoke trigger event for app %s",
plugin_trigger.app_id,
)
continue
if invoke_response is not None and invoke_response.cancelled:
quota_charge.refund()
logger.info(
"Trigger ignored for app %s with trigger event %s",
plugin_trigger.app_id,
event_name,
)
continue
# Create trigger data for async execution
trigger_data = PluginTriggerData(
app_id=plugin_trigger.app_id,
tenant_id=subscription.tenant_id,
workflow_id=workflow.id,
root_node_id=plugin_trigger.node_id,
plugin_id=subscription.provider_id,
endpoint_id=subscription.endpoint_id,
inputs=invoke_response.variables,
trigger_metadata=trigger_metadata,
)
# Trigger async workflow
try:
end_user = end_users.get(plugin_trigger.app_id)
if not end_user:
raise ValueError(f"End user not found for app {plugin_trigger.app_id}")
AsyncWorkflowService.trigger_workflow_async(session=session, user=end_user, trigger_data=trigger_data)
dispatched_count += 1
logger.info(
"Triggered workflow for app %s with trigger event %s",
plugin_trigger.app_id,
event_name,
)
except Exception:
quota_charge.refund()
logger.exception(
"Failed to trigger workflow for app %s",
plugin_trigger.app_id,
)
return dispatched_count
def dispatch_triggered_workflows(
user_id: str,
events: list[str],
subscription: TriggerSubscription,
request_id: str,
) -> int:
dispatched_count = 0
for event_name in events:
try:
dispatched_count += dispatch_triggered_workflow(
user_id=user_id,
subscription=subscription,
event_name=event_name,
request_id=request_id,
)
except Exception:
logger.exception(
"Failed to dispatch trigger '%s' for subscription %s and provider %s. Continuing...",
event_name,
subscription.id,
subscription.provider_id,
)
# Continue processing other triggers even if one fails
continue
logger.info(
"Completed async trigger dispatching: processed %d/%d triggers for subscription %s and provider %s",
dispatched_count,
len(events),
subscription.id,
subscription.provider_id,
)
return dispatched_count
@shared_task(queue=TRIGGER_QUEUE)
def dispatch_triggered_workflows_async(
dispatch_data: Mapping[str, Any],
) -> Mapping[str, Any]:
"""
Dispatch triggers asynchronously.
Args:
endpoint_id: Endpoint ID
provider_id: Provider ID
subscription_id: Subscription ID
timestamp: Timestamp of the event
triggers: List of triggers to dispatch
request_id: Unique ID of the stored request
Returns:
dict: Execution result with status and dispatched trigger count
"""
dispatch_params: PluginTriggerDispatchData = PluginTriggerDispatchData.model_validate(dispatch_data)
user_id = dispatch_params.user_id
tenant_id = dispatch_params.tenant_id
endpoint_id = dispatch_params.endpoint_id
provider_id = dispatch_params.provider_id
subscription_id = dispatch_params.subscription_id
timestamp = dispatch_params.timestamp
events = dispatch_params.events
request_id = dispatch_params.request_id
try:
logger.info(
"Starting trigger dispatching uid=%s, endpoint=%s, events=%s, req_id=%s, sub_id=%s, provider_id=%s",
user_id,
endpoint_id,
events,
request_id,
subscription_id,
provider_id,
)
subscription: TriggerSubscription | None = TriggerProviderService.get_subscription_by_id(
tenant_id=tenant_id,
subscription_id=subscription_id,
)
if not subscription:
logger.error("Subscription not found: %s", subscription_id)
return {"status": "failed", "error": "Subscription not found"}
workflow_dispatched = dispatch_triggered_workflows(
user_id=user_id,
events=events,
subscription=subscription,
request_id=request_id,
)
debug_dispatched = dispatch_trigger_debug_event(
events=events,
user_id=user_id,
timestamp=timestamp,
request_id=request_id,
subscription=subscription,
)
return {
"status": "completed",
"total_count": len(events),
"workflows": workflow_dispatched,
"debug_events": debug_dispatched,
}
except Exception as e:
logger.exception(
"Error in async trigger dispatching for endpoint %s data %s for subscription %s and provider %s",
endpoint_id,
dispatch_data,
subscription_id,
provider_id,
)
return {
"status": "failed",
"error": str(e),
}

View File

@@ -0,0 +1,119 @@
import logging
import time
from collections.abc import Mapping
from typing import Any
from celery import shared_task
from sqlalchemy.orm import Session
from configs import dify_config
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.utils.locks import build_trigger_refresh_lock_key
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.trigger import TriggerSubscription
from services.trigger.trigger_provider_service import TriggerProviderService
logger = logging.getLogger(__name__)
def _now_ts() -> int:
return int(time.time())
def _load_subscription(session: Session, tenant_id: str, subscription_id: str) -> TriggerSubscription | None:
return session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
def _refresh_oauth_if_expired(tenant_id: str, subscription: TriggerSubscription, now: int) -> None:
threshold_seconds: int = int(dify_config.TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS)
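# Refresh proactively when the OAuth credential will expire within the configured threshold
# window (for example, a 300-second threshold refreshes tokens expiring in the next five
# minutes; the actual value comes from config). An expiry of -1 is treated as non-expiring.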
if (
subscription.credential_expires_at != -1
and int(subscription.credential_expires_at) <= now + threshold_seconds
and CredentialType.of(subscription.credential_type) == CredentialType.OAUTH2
):
logger.info(
"Refreshing OAuth token: tenant=%s subscription_id=%s expires_at=%s now=%s",
tenant_id,
subscription.id,
subscription.credential_expires_at,
now,
)
try:
result: Mapping[str, Any] = TriggerProviderService.refresh_oauth_token(
tenant_id=tenant_id, subscription_id=subscription.id
)
logger.info(
"OAuth token refreshed: tenant=%s subscription_id=%s result=%s", tenant_id, subscription.id, result
)
except Exception:
logger.exception("OAuth refresh failed: tenant=%s subscription_id=%s", tenant_id, subscription.id)
def _refresh_subscription_if_expired(
tenant_id: str,
subscription: TriggerSubscription,
now: int,
) -> None:
threshold_seconds: int = int(dify_config.TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS)
if subscription.expires_at == -1 or int(subscription.expires_at) > now + threshold_seconds:
logger.debug(
"Subscription not due: tenant=%s subscription_id=%s expires_at=%s now=%s threshold=%s",
tenant_id,
subscription.id,
subscription.expires_at,
now,
threshold_seconds,
)
return
try:
result: Mapping[str, Any] = TriggerProviderService.refresh_subscription(
tenant_id=tenant_id, subscription_id=subscription.id, now=now
)
logger.info(
"Subscription refreshed: tenant=%s subscription_id=%s result=%s",
tenant_id,
subscription.id,
result.get("result"),
)
except Exception:
logger.exception("Subscription refresh failed: tenant=%s id=%s", tenant_id, subscription.id)
@shared_task(queue="trigger_refresh_executor")
def trigger_subscription_refresh(tenant_id: str, subscription_id: str) -> None:
"""Refresh a trigger subscription if needed, guarded by a Redis in-flight lock."""
lock_key: str = build_trigger_refresh_lock_key(tenant_id, subscription_id)
if not redis_client.get(lock_key):
logger.debug("Refresh lock missing, skip: %s", lock_key)
return
logger.info("Begin subscription refresh: tenant=%s id=%s", tenant_id, subscription_id)
try:
now: int = _now_ts()
with Session(db.engine) as session:
subscription: TriggerSubscription | None = _load_subscription(session, tenant_id, subscription_id)
if not subscription:
logger.warning("Subscription not found: tenant=%s id=%s", tenant_id, subscription_id)
return
logger.debug(
"Loaded subscription: tenant=%s id=%s cred_exp=%s sub_exp=%s now=%s",
tenant_id,
subscription.id,
subscription.credential_expires_at,
subscription.expires_at,
now,
)
_refresh_oauth_if_expired(tenant_id=tenant_id, subscription=subscription, now=now)
_refresh_subscription_if_expired(tenant_id=tenant_id, subscription=subscription, now=now)
finally:
try:
redis_client.delete(lock_key)
logger.debug("Lock released: %s", lock_key)
except Exception:
# Best-effort lock cleanup
logger.warning("Failed to release lock: %s", lock_key, exc_info=True)

View File

@@ -0,0 +1,32 @@
from services.workflow.entities import WorkflowScheduleCFSPlanEntity
from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand
from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue
class AsyncWorkflowCFSPlanEntity(WorkflowScheduleCFSPlanEntity):
"""
    Async workflow CFS plan entity.
"""
queue: AsyncWorkflowQueue
class AsyncWorkflowCFSPlanScheduler(CFSPlanScheduler):
"""
    Async workflow CFS plan scheduler.
"""
plan: AsyncWorkflowCFSPlanEntity
def can_schedule(self) -> SchedulerCommand:
"""
Check if the workflow can be scheduled.
"""
if self.plan.queue in [AsyncWorkflowQueue.PROFESSIONAL_QUEUE, AsyncWorkflowQueue.TEAM_QUEUE]:
"""
permitted all paid users to schedule the workflow any time
"""
return SchedulerCommand.NONE
        # FIXME: prevent a sandbox user's workflow from staying in the running state forever.
return SchedulerCommand.RESOURCE_LIMIT_REACHED
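# Hypothetical consumer sketch (not part of this commit): how a dispatcher might act
# on the command returned by can_schedule(); the deferral policy itself is assumed.
def _should_dispatch(scheduler: AsyncWorkflowCFSPlanScheduler) -> bool:
    command = scheduler.can_schedule()
    # NONE means the plan may run now; any other command tells the caller to defer.
    return command == SchedulerCommand.NONE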

View File

@@ -0,0 +1,25 @@
from enum import StrEnum
from configs import dify_config
from services.workflow.entities import WorkflowScheduleCFSPlanEntity
# Determine queue names based on edition
if dify_config.EDITION == "CLOUD":
# Cloud edition: separate queues for different tiers
_professional_queue = "workflow_professional"
_team_queue = "workflow_team"
_sandbox_queue = "workflow_sandbox"
AsyncWorkflowSystemStrategy = WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice
else:
    # Community edition: all tiers share the single "workflow" queue (distinct from the dataset queue)
_professional_queue = "workflow"
_team_queue = "workflow"
_sandbox_queue = "workflow"
AsyncWorkflowSystemStrategy = WorkflowScheduleCFSPlanEntity.Strategy.Nop
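# Note: in the community edition all three members below share the value "workflow",
# so standard Enum aliasing makes TEAM_QUEUE and SANDBOX_QUEUE aliases of
# PROFESSIONAL_QUEUE; every tier then resolves to the same Celery queue.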
class AsyncWorkflowQueue(StrEnum):
# Define constants
PROFESSIONAL_QUEUE = _professional_queue
TEAM_QUEUE = _team_queue
SANDBOX_QUEUE = _sandbox_queue

View File

@@ -0,0 +1,22 @@
"""
Celery tasks for asynchronous workflow execution storage operations.
These tasks provide asynchronous storage capabilities for workflow execution data,
improving performance by offloading storage operations to background workers.
"""
from celery import shared_task # type: ignore[import-untyped]
from sqlalchemy.orm import Session
from extensions.ext_database import db
from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService
@shared_task(queue="workflow_draft_var", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_execution_task(
self,
deletions: list[DraftVarFileDeletion],
):
with Session(bind=db.engine) as session, session.begin():
srv = WorkflowDraftVariableService(session=session)
srv.delete_workflow_draft_variable_file(deletions=deletions)

View File

@@ -0,0 +1,136 @@
"""
Celery tasks for asynchronous workflow execution storage operations.
These tasks provide asynchronous storage capabilities for workflow execution data,
improving performance by offloading storage operations to background workers.
"""
import json
import logging
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from models import CreatorUserRole, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_execution_task(
self,
execution_data: dict,
tenant_id: str,
app_id: str,
triggered_from: str,
creator_user_id: str,
creator_user_role: str,
) -> bool:
"""
Asynchronously save or update a workflow execution to the database.
Args:
execution_data: Serialized WorkflowExecution data
tenant_id: Tenant ID for multi-tenancy
app_id: Application ID
triggered_from: Source of the execution trigger
creator_user_id: ID of the user who created the execution
creator_user_role: Role of the user who created the execution
Returns:
        True if the save succeeded; on failure the task retries with exponential backoff and re-raises.
"""
try:
# Create a new session for this task
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
# Deserialize execution data
execution = WorkflowExecution.model_validate(execution_data)
# Check if workflow run already exists
existing_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == execution.id_))
if existing_run:
# Update existing workflow run
_update_workflow_run_from_execution(existing_run, execution)
logger.debug("Updated existing workflow run: %s", execution.id_)
else:
# Create new workflow run
workflow_run = _create_workflow_run_from_execution(
execution=execution,
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowRunTriggeredFrom(triggered_from),
creator_user_id=creator_user_id,
creator_user_role=CreatorUserRole(creator_user_role),
)
session.add(workflow_run)
logger.debug("Created new workflow run: %s", execution.id_)
session.commit()
return True
except Exception as e:
logger.exception("Failed to save workflow execution %s", execution_data.get("id_", "unknown"))
# Retry the task with exponential backoff
raise self.retry(exc=e, countdown=60 * (2**self.request.retries))
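# Hypothetical usage sketch (not part of this commit): a repository layer might enqueue
# the task with a JSON-safe dump of the domain entity; the exact call site, variable
# names, and serialization mode are assumptions:
#
#     save_workflow_execution_task.delay(
#         execution_data=execution.model_dump(mode="json"),
#         tenant_id=tenant_id,
#         app_id=app_id,
#         triggered_from=triggered_from.value,
#         creator_user_id=user_id,
#         creator_user_role=creator_role.value,
#     )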
def _create_workflow_run_from_execution(
execution: WorkflowExecution,
tenant_id: str,
app_id: str,
triggered_from: WorkflowRunTriggeredFrom,
creator_user_id: str,
creator_user_role: CreatorUserRole,
) -> WorkflowRun:
"""
Create a WorkflowRun database model from a WorkflowExecution domain entity.
"""
workflow_run = WorkflowRun()
workflow_run.id = execution.id_
workflow_run.tenant_id = tenant_id
workflow_run.app_id = app_id
workflow_run.workflow_id = execution.workflow_id
workflow_run.type = execution.workflow_type.value
workflow_run.triggered_from = triggered_from.value
workflow_run.version = execution.workflow_version
json_converter = WorkflowRuntimeTypeConverter()
workflow_run.graph = json.dumps(json_converter.to_json_encodable(execution.graph))
workflow_run.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs))
workflow_run.status = execution.status.value
workflow_run.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
workflow_run.error = execution.error_message
workflow_run.elapsed_time = execution.elapsed_time
workflow_run.total_tokens = execution.total_tokens
workflow_run.total_steps = execution.total_steps
workflow_run.created_by_role = creator_user_role.value
workflow_run.created_by = creator_user_id
workflow_run.created_at = execution.started_at
workflow_run.finished_at = execution.finished_at
return workflow_run
def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: WorkflowExecution):
"""
Update a WorkflowRun database model from a WorkflowExecution domain entity.
"""
json_converter = WorkflowRuntimeTypeConverter()
workflow_run.status = execution.status.value
workflow_run.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
workflow_run.error = execution.error_message
workflow_run.elapsed_time = execution.elapsed_time
workflow_run.total_tokens = execution.total_tokens
workflow_run.total_steps = execution.total_steps
workflow_run.finished_at = execution.finished_at

View File

@@ -0,0 +1,169 @@
"""
Celery tasks for asynchronous workflow node execution storage operations.
These tasks provide asynchronous storage capabilities for workflow node execution data,
improving performance by offloading storage operations to background workers.
"""
import json
import logging
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
)
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from models import CreatorUserRole, WorkflowNodeExecutionModel
from models.workflow import WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_node_execution_task(
self,
execution_data: dict,
tenant_id: str,
app_id: str,
triggered_from: str,
creator_user_id: str,
creator_user_role: str,
) -> bool:
"""
Asynchronously save or update a workflow node execution to the database.
Args:
execution_data: Serialized WorkflowNodeExecution data
tenant_id: Tenant ID for multi-tenancy
app_id: Application ID
triggered_from: Source of the execution trigger
creator_user_id: ID of the user who created the execution
creator_user_role: Role of the user who created the execution
Returns:
        True if the save succeeded; on failure the task retries with exponential backoff and re-raises.
"""
try:
# Create a new session for this task
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
# Deserialize execution data
execution = WorkflowNodeExecution.model_validate(execution_data)
# Check if node execution already exists
existing_execution = session.scalar(
select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution.id)
)
if existing_execution:
# Update existing node execution
_update_node_execution_from_domain(existing_execution, execution)
logger.debug("Updated existing workflow node execution: %s", execution.id)
else:
# Create new node execution
node_execution = _create_node_execution_from_domain(
execution=execution,
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom(triggered_from),
creator_user_id=creator_user_id,
creator_user_role=CreatorUserRole(creator_user_role),
)
session.add(node_execution)
logger.debug("Created new workflow node execution: %s", execution.id)
session.commit()
return True
except Exception as e:
logger.exception("Failed to save workflow node execution %s", execution_data.get("id", "unknown"))
# Retry the task with exponential backoff
raise self.retry(exc=e, countdown=60 * (2**self.request.retries))
def _create_node_execution_from_domain(
execution: WorkflowNodeExecution,
tenant_id: str,
app_id: str,
triggered_from: WorkflowNodeExecutionTriggeredFrom,
creator_user_id: str,
creator_user_role: CreatorUserRole,
) -> WorkflowNodeExecutionModel:
"""
Create a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
"""
node_execution = WorkflowNodeExecutionModel()
node_execution.id = execution.id
node_execution.tenant_id = tenant_id
node_execution.app_id = app_id
node_execution.workflow_id = execution.workflow_id
node_execution.triggered_from = triggered_from.value
node_execution.workflow_run_id = execution.workflow_execution_id
node_execution.index = execution.index
node_execution.predecessor_node_id = execution.predecessor_node_id
node_execution.node_id = execution.node_id
node_execution.node_type = execution.node_type.value
node_execution.title = execution.title
node_execution.node_execution_id = execution.node_execution_id
# Serialize complex data as JSON
json_converter = WorkflowRuntimeTypeConverter()
node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
node_execution.process_data = (
json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
)
node_execution.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
# Convert metadata enum keys to strings for JSON serialization
if execution.metadata:
metadata_for_json = {
key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
}
node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json))
else:
node_execution.execution_metadata = "{}"
node_execution.status = execution.status.value
node_execution.error = execution.error
node_execution.elapsed_time = execution.elapsed_time
node_execution.created_by_role = creator_user_role.value
node_execution.created_by = creator_user_id
node_execution.created_at = execution.created_at
node_execution.finished_at = execution.finished_at
return node_execution
def _update_node_execution_from_domain(node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution):
"""
Update a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
"""
# Update serialized data
json_converter = WorkflowRuntimeTypeConverter()
node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
node_execution.process_data = (
json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
)
node_execution.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
# Convert metadata enum keys to strings for JSON serialization
if execution.metadata:
metadata_for_json = {
key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
}
node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json))
else:
node_execution.execution_metadata = "{}"
# Update other fields
node_execution.status = execution.status.value
node_execution.error = execution.error
node_execution.elapsed_time = execution.elapsed_time
node_execution.finished_at = execution.finished_at

View File

@@ -0,0 +1,73 @@
import logging
from celery import shared_task
from sqlalchemy.orm import sessionmaker
from core.workflow.nodes.trigger_schedule.exc import (
ScheduleExecutionError,
ScheduleNotFoundError,
TenantOwnerNotFoundError,
)
from enums.quota_type import QuotaType, unlimited
from extensions.ext_database import db
from models.trigger import WorkflowSchedulePlan
from services.async_workflow_service import AsyncWorkflowService
from services.errors.app import QuotaExceededError
from services.trigger.app_trigger_service import AppTriggerService
from services.trigger.schedule_service import ScheduleService
from services.workflow.entities import ScheduleTriggerData
logger = logging.getLogger(__name__)
@shared_task(queue="schedule_executor")
def run_schedule_trigger(schedule_id: str) -> None:
"""
Execute a scheduled workflow trigger.
    Note: no retry logic is needed, because the schedule will fire again at its next interval.
The execution result is tracked via WorkflowTriggerLog.
Raises:
ScheduleNotFoundError: If schedule doesn't exist
TenantOwnerNotFoundError: If no owner/admin for tenant
ScheduleExecutionError: If workflow trigger fails
"""
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
schedule = session.get(WorkflowSchedulePlan, schedule_id)
if not schedule:
raise ScheduleNotFoundError(f"Schedule {schedule_id} not found")
tenant_owner = ScheduleService.get_tenant_owner(session, schedule.tenant_id)
if not tenant_owner:
raise TenantOwnerNotFoundError(f"No owner or admin found for tenant {schedule.tenant_id}")
quota_charge = unlimited()
try:
quota_charge = QuotaType.TRIGGER.consume(schedule.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
return
try:
# Production dispatch: Trigger the workflow normally
response = AsyncWorkflowService.trigger_workflow_async(
session=session,
user=tenant_owner,
trigger_data=ScheduleTriggerData(
app_id=schedule.app_id,
root_node_id=schedule.node_id,
inputs={},
tenant_id=schedule.tenant_id,
),
)
logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
except Exception as e:
quota_charge.refund()
raise ScheduleExecutionError(
f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}"
) from e
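# Hypothetical producer sketch (not part of this commit): a beat/poller process is
# expected to enqueue due plans by id, e.g.:
#
#     run_schedule_trigger.delay(schedule.id)
#
# Any exception raised above only surfaces in the worker log; the plan simply fires
# again at its next interval, as noted in the docstring.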