dify
0
dify/api/tasks/__init__.py
Normal file
117
dify/api/tasks/add_document_to_index_task.py
Normal file
@@ -0,0 +1,117 @@
import logging
import time

import click
from celery import shared_task

from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DatasetAutoDisableLog, DocumentSegment
from models.dataset import Document as DatasetDocument

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def add_document_to_index_task(dataset_document_id: str):
    """
    Async add document to index
    :param dataset_document_id:

    Usage: add_document_to_index_task.delay(dataset_document_id)
    """
    logger.info(click.style(f"Start add document to index: {dataset_document_id}", fg="green"))
    start_at = time.perf_counter()

    dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first()
    if not dataset_document:
        logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
        db.session.close()
        return

    if dataset_document.indexing_status != "completed":
        db.session.close()
        return

    indexing_cache_key = f"document_{dataset_document.id}_indexing"

    try:
        dataset = dataset_document.dataset
        if not dataset:
            raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.")

        segments = (
            db.session.query(DocumentSegment)
            .where(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.status == "completed",
            )
            .order_by(DocumentSegment.position.asc())
            .all()
        )

        documents = []
        for segment in segments:
            document = Document(
                page_content=segment.content,
                metadata={
                    "doc_id": segment.index_node_id,
                    "doc_hash": segment.index_node_hash,
                    "document_id": segment.document_id,
                    "dataset_id": segment.dataset_id,
                },
            )
            if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                child_chunks = segment.get_child_chunks()
                if child_chunks:
                    child_documents = []
                    for child_chunk in child_chunks:
                        child_document = ChildDocument(
                            page_content=child_chunk.content,
                            metadata={
                                "doc_id": child_chunk.index_node_id,
                                "doc_hash": child_chunk.index_node_hash,
                                "document_id": segment.document_id,
                                "dataset_id": segment.dataset_id,
                            },
                        )
                        child_documents.append(child_document)
                    document.children = child_documents
            documents.append(document)

        index_type = dataset.doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        index_processor.load(dataset, documents)

        # delete auto disable log
        db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete()

        # update segment to enable
        db.session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update(
            {
                DocumentSegment.enabled: True,
                DocumentSegment.disabled_at: None,
                DocumentSegment.disabled_by: None,
                DocumentSegment.updated_at: naive_utc_now(),
            }
        )
        db.session.commit()

        end_at = time.perf_counter()
        logger.info(
            click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green")
        )
    except Exception as e:
        logger.exception("add document to index failed")
        dataset_document.enabled = False
        dataset_document.disabled_at = naive_utc_now()
        dataset_document.indexing_status = "error"
        dataset_document.error = str(e)
        db.session.commit()
    finally:
        redis_client.delete(indexing_cache_key)
        db.session.close()
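Note: the task above is only the consumer side; it is enqueued with Celery's `.delay(...)`, as its docstring indicates. A minimal caller sketch follows (hypothetical service-side code, not part of this commit):

# Hypothetical caller sketch -- not part of this commit.
# Only the document id crosses the broker; the task re-reads all state from the database.
from tasks.add_document_to_index_task import add_document_to_index_task


def enqueue_document_indexing(dataset_document_id: str) -> None:
    add_document_to_index_task.delay(dataset_document_id)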
62
dify/api/tasks/annotation/add_annotation_to_index_task.py
Normal file
@@ -0,0 +1,62 @@
import logging
import time

import click
from celery import shared_task

from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def add_annotation_to_index_task(
    annotation_id: str, question: str, tenant_id: str, app_id: str, collection_binding_id: str
):
    """
    Add annotation to index.
    :param annotation_id: annotation id
    :param question: question
    :param tenant_id: tenant id
    :param app_id: app id
    :param collection_binding_id: embedding binding id

    Usage: add_annotation_to_index_task.delay(annotation_id, question, tenant_id, app_id, collection_binding_id)
    """
    logger.info(click.style(f"Start build index for annotation: {annotation_id}", fg="green"))
    start_at = time.perf_counter()

    try:
        dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
            collection_binding_id, "annotation"
        )
        dataset = Dataset(
            id=app_id,
            tenant_id=tenant_id,
            indexing_technique="high_quality",
            embedding_model_provider=dataset_collection_binding.provider_name,
            embedding_model=dataset_collection_binding.model_name,
            collection_binding_id=dataset_collection_binding.id,
        )

        document = Document(
            page_content=question, metadata={"annotation_id": annotation_id, "app_id": app_id, "doc_id": annotation_id}
        )
        vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
        vector.create([document], duplicate_check=True)

        end_at = time.perf_counter()
        logger.info(
            click.style(
                f"Build index successful for annotation: {annotation_id} latency: {end_at - start_at}",
                fg="green",
            )
        )
    except Exception:
        logger.exception("Build index for annotation failed")
    finally:
        db.session.close()
94
dify/api/tasks/annotation/batch_import_annotations_task.py
Normal file
@@ -0,0 +1,94 @@
import logging
import time

import click
from celery import shared_task
from werkzeug.exceptions import NotFound

from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
from services.dataset_service import DatasetCollectionBindingService

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: str, tenant_id: str, user_id: str):
    """
    Batch import annotations and add them to the index.
    :param job_id: job_id
    :param content_list: content list
    :param app_id: app id
    :param tenant_id: tenant id
    :param user_id: user_id

    """
    logger.info(click.style(f"Start batch import annotation: {job_id}", fg="green"))
    start_at = time.perf_counter()
    indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
    # get app info
    app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()

    if app:
        try:
            documents = []
            for content in content_list:
                annotation = MessageAnnotation(
                    app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id
                )
                db.session.add(annotation)
                db.session.flush()

                document = Document(
                    page_content=content["question"],
                    metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
                )
                documents.append(document)
            # if annotation reply is enabled, batch add the annotations to the index
            app_annotation_setting = (
                db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
            )

            if app_annotation_setting:
                dataset_collection_binding = (
                    DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
                        app_annotation_setting.collection_binding_id, "annotation"
                    )
                )
                if not dataset_collection_binding:
                    raise NotFound("App annotation setting not found")
                dataset = Dataset(
                    id=app_id,
                    tenant_id=tenant_id,
                    indexing_technique="high_quality",
                    embedding_model_provider=dataset_collection_binding.provider_name,
                    embedding_model=dataset_collection_binding.model_name,
                    collection_binding_id=dataset_collection_binding.id,
                )

                vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
                vector.create(documents, duplicate_check=True)

            db.session.commit()
            redis_client.setex(indexing_cache_key, 600, "completed")
            end_at = time.perf_counter()
            logger.info(
                click.style(
                    "Build index successful for batch import annotation: {} latency: {}".format(
                        job_id, end_at - start_at
                    ),
                    fg="green",
                )
            )
        except Exception as e:
            db.session.rollback()
            redis_client.setex(indexing_cache_key, 600, "error")
            indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}"
            redis_client.setex(indexing_error_msg_key, 600, str(e))
            logger.exception("Build index for batch import annotations failed")
        finally:
            db.session.close()
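Note: each entry in content_list must carry "question" and "answer" keys, and callers can poll the Redis keys the task writes to track job status. A minimal producer-side sketch (hypothetical, not part of this commit; redis_client is the client from extensions.ext_redis):

# Hypothetical producer sketch -- not part of this commit.
import uuid

from extensions.ext_redis import redis_client
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task


def start_batch_import(app_id: str, tenant_id: str, user_id: str) -> str:
    job_id = str(uuid.uuid4())
    content_list = [
        {"question": "What is Dify?", "answer": "An LLM application development platform."},
    ]
    batch_import_annotations_task.delay(job_id, content_list, app_id, tenant_id, user_id)
    return job_id


def poll_batch_import(job_id: str) -> str | None:
    # The task writes "completed" or "error" to this key with a 600-second TTL.
    status = redis_client.get(f"app_annotation_batch_import_{job_id}")
    # Depending on the Redis client configuration the value may be bytes; normalize to str.
    return status.decode() if isinstance(status, bytes) else status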
44
dify/api/tasks/annotation/delete_annotation_index_task.py
Normal file
@@ -0,0 +1,44 @@
import logging
import time

import click
from celery import shared_task

from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str, collection_binding_id: str):
    """
    Async delete annotation index task
    """
    logger.info(click.style(f"Start delete app annotation index: {app_id}", fg="green"))
    start_at = time.perf_counter()
    try:
        dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
            collection_binding_id, "annotation"
        )

        dataset = Dataset(
            id=app_id,
            tenant_id=tenant_id,
            indexing_technique="high_quality",
            collection_binding_id=dataset_collection_binding.id,
        )

        try:
            vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
            vector.delete_by_metadata_field("annotation_id", annotation_id)
        except Exception:
            logger.exception("Delete annotation index failed when annotation deleted.")
        end_at = time.perf_counter()
        logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
    except Exception:
        logger.exception("Annotation deleted index failed")
    finally:
        db.session.close()
71
dify/api/tasks/annotation/disable_annotation_reply_task.py
Normal file
@@ -0,0 +1,71 @@
import logging
import time

import click
from celery import shared_task
from sqlalchemy import exists, select

from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
    """
    Async disable annotation reply task
    """
    logger.info(click.style(f"Start delete app annotations index: {app_id}", fg="green"))
    start_at = time.perf_counter()
    # get app info
    app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
    annotations_exists = db.session.scalar(select(exists().where(MessageAnnotation.app_id == app_id)))
    if not app:
        logger.info(click.style(f"App not found: {app_id}", fg="red"))
        db.session.close()
        return

    app_annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()

    if not app_annotation_setting:
        logger.info(click.style(f"App annotation setting not found: {app_id}", fg="red"))
        db.session.close()
        return

    disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
    disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}"

    try:
        dataset = Dataset(
            id=app_id,
            tenant_id=tenant_id,
            indexing_technique="high_quality",
            collection_binding_id=app_annotation_setting.collection_binding_id,
        )

        try:
            if annotations_exists:
                vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
                vector.delete()
        except Exception:
            logger.exception("Delete annotation index failed when annotation deleted.")
        redis_client.setex(disable_app_annotation_job_key, 600, "completed")

        # delete annotation setting
        db.session.delete(app_annotation_setting)
        db.session.commit()

        end_at = time.perf_counter()
        logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
    except Exception as e:
        logger.exception("Annotation batch deleted index failed")
        redis_client.setex(disable_app_annotation_job_key, 600, "error")
        disable_app_annotation_error_key = f"disable_app_annotation_error_{str(job_id)}"
        redis_client.setex(disable_app_annotation_error_key, 600, str(e))
    finally:
        redis_client.delete(disable_app_annotation_key)
        db.session.close()
124
dify/api/tasks/annotation/enable_annotation_reply_task.py
Normal file
@@ -0,0 +1,124 @@
import logging
import time

import click
from celery import shared_task
from sqlalchemy import select

from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
from services.dataset_service import DatasetCollectionBindingService

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def enable_annotation_reply_task(
    job_id: str,
    app_id: str,
    user_id: str,
    tenant_id: str,
    score_threshold: float,
    embedding_provider_name: str,
    embedding_model_name: str,
):
    """
    Async enable annotation reply task
    """
    logger.info(click.style(f"Start add app annotation to index: {app_id}", fg="green"))
    start_at = time.perf_counter()
    # get app info
    app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()

    if not app:
        logger.info(click.style(f"App not found: {app_id}", fg="red"))
        db.session.close()
        return

    annotations = db.session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all()
    enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
    enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"

    try:
        documents = []
        dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
            embedding_provider_name, embedding_model_name, "annotation"
        )
        annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
        if annotation_setting:
            if dataset_collection_binding.id != annotation_setting.collection_binding_id:
                old_dataset_collection_binding = (
                    DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
                        annotation_setting.collection_binding_id, "annotation"
                    )
                )
                if old_dataset_collection_binding and annotations:
                    old_dataset = Dataset(
                        id=app_id,
                        tenant_id=tenant_id,
                        indexing_technique="high_quality",
                        embedding_model_provider=old_dataset_collection_binding.provider_name,
                        embedding_model=old_dataset_collection_binding.model_name,
                        collection_binding_id=old_dataset_collection_binding.id,
                    )

                    old_vector = Vector(old_dataset, attributes=["doc_id", "annotation_id", "app_id"])
                    try:
                        old_vector.delete()
                    except Exception as e:
                        logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
            annotation_setting.score_threshold = score_threshold
            annotation_setting.collection_binding_id = dataset_collection_binding.id
            annotation_setting.updated_user_id = user_id
            annotation_setting.updated_at = naive_utc_now()
            db.session.add(annotation_setting)
        else:
            new_app_annotation_setting = AppAnnotationSetting(
                app_id=app_id,
                score_threshold=score_threshold,
                collection_binding_id=dataset_collection_binding.id,
                created_user_id=user_id,
                updated_user_id=user_id,
            )
            db.session.add(new_app_annotation_setting)

        dataset = Dataset(
            id=app_id,
            tenant_id=tenant_id,
            indexing_technique="high_quality",
            embedding_model_provider=embedding_provider_name,
            embedding_model=embedding_model_name,
            collection_binding_id=dataset_collection_binding.id,
        )
        if annotations:
            for annotation in annotations:
                document = Document(
                    page_content=annotation.question,
                    metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
                )
                documents.append(document)

            vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
            try:
                vector.delete_by_metadata_field("app_id", app_id)
            except Exception as e:
                logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
            vector.create(documents)
        db.session.commit()
        redis_client.setex(enable_app_annotation_job_key, 600, "completed")
        end_at = time.perf_counter()
        logger.info(click.style(f"App annotations added to index: {app_id} latency: {end_at - start_at}", fg="green"))
    except Exception as e:
        logger.exception("Annotation batch created index failed")
        redis_client.setex(enable_app_annotation_job_key, 600, "error")
        enable_app_annotation_error_key = f"enable_app_annotation_error_{str(job_id)}"
        redis_client.setex(enable_app_annotation_error_key, 600, str(e))
        db.session.rollback()
    finally:
        redis_client.delete(enable_app_annotation_key)
        db.session.close()
63
dify/api/tasks/annotation/update_annotation_to_index_task.py
Normal file
@@ -0,0 +1,63 @@
import logging
import time

import click
from celery import shared_task

from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def update_annotation_to_index_task(
    annotation_id: str, question: str, tenant_id: str, app_id: str, collection_binding_id: str
):
    """
    Update annotation to index.
    :param annotation_id: annotation id
    :param question: question
    :param tenant_id: tenant id
    :param app_id: app id
    :param collection_binding_id: embedding binding id

    Usage: update_annotation_to_index_task.delay(annotation_id, question, tenant_id, app_id, collection_binding_id)
    """
    logger.info(click.style(f"Start update index for annotation: {annotation_id}", fg="green"))
    start_at = time.perf_counter()

    try:
        dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
            collection_binding_id, "annotation"
        )

        dataset = Dataset(
            id=app_id,
            tenant_id=tenant_id,
            indexing_technique="high_quality",
            embedding_model_provider=dataset_collection_binding.provider_name,
            embedding_model=dataset_collection_binding.model_name,
            collection_binding_id=dataset_collection_binding.id,
        )

        document = Document(
            page_content=question, metadata={"annotation_id": annotation_id, "app_id": app_id, "doc_id": annotation_id}
        )
        vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
        vector.delete_by_metadata_field("annotation_id", annotation_id)
        vector.add_texts([document])
        end_at = time.perf_counter()
        logger.info(
            click.style(
                f"Build index successful for annotation: {annotation_id} latency: {end_at - start_at}",
                fg="green",
            )
        )
    except Exception:
        logger.exception("Build index for annotation failed")
    finally:
        db.session.close()
196
dify/api/tasks/async_workflow_tasks.py
Normal file
@@ -0,0 +1,196 @@
"""
Celery tasks for async workflow execution.

These tasks handle workflow execution for different subscription tiers
with appropriate retry policies and error handling.
"""

from datetime import UTC, datetime
from typing import Any

from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker

from configs import dify_config
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.layers.trigger_post_layer import TriggerPostLayer
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatorUserRole, WorkflowTriggerStatus
from models.model import App, EndUser, Tenant
from models.trigger import WorkflowTriggerLog
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import WorkflowNotFoundError
from services.workflow.entities import (
    TriggerData,
    WorkflowTaskData,
)
from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity, AsyncWorkflowCFSPlanScheduler
from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkflowSystemStrategy


@shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE)
def execute_workflow_professional(task_data_dict: dict[str, Any]):
    """Execute workflow for professional tier with highest priority"""
    task_data = WorkflowTaskData.model_validate(task_data_dict)
    cfs_plan_scheduler_entity = AsyncWorkflowCFSPlanEntity(
        queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE,
        schedule_strategy=AsyncWorkflowSystemStrategy,
        granularity=dify_config.ASYNC_WORKFLOW_SCHEDULER_GRANULARITY,
    )
    _execute_workflow_common(
        task_data,
        AsyncWorkflowCFSPlanScheduler(plan=cfs_plan_scheduler_entity),
        cfs_plan_scheduler_entity,
    )


@shared_task(queue=AsyncWorkflowQueue.TEAM_QUEUE)
def execute_workflow_team(task_data_dict: dict[str, Any]):
    """Execute workflow for team tier"""
    task_data = WorkflowTaskData.model_validate(task_data_dict)
    cfs_plan_scheduler_entity = AsyncWorkflowCFSPlanEntity(
        queue=AsyncWorkflowQueue.TEAM_QUEUE,
        schedule_strategy=AsyncWorkflowSystemStrategy,
        granularity=dify_config.ASYNC_WORKFLOW_SCHEDULER_GRANULARITY,
    )
    _execute_workflow_common(
        task_data,
        AsyncWorkflowCFSPlanScheduler(plan=cfs_plan_scheduler_entity),
        cfs_plan_scheduler_entity,
    )


@shared_task(queue=AsyncWorkflowQueue.SANDBOX_QUEUE)
def execute_workflow_sandbox(task_data_dict: dict[str, Any]):
    """Execute workflow for free tier with lower retry limit"""
    task_data = WorkflowTaskData.model_validate(task_data_dict)
    cfs_plan_scheduler_entity = AsyncWorkflowCFSPlanEntity(
        queue=AsyncWorkflowQueue.SANDBOX_QUEUE,
        schedule_strategy=AsyncWorkflowSystemStrategy,
        granularity=dify_config.ASYNC_WORKFLOW_SCHEDULER_GRANULARITY,
    )
    _execute_workflow_common(
        task_data,
        AsyncWorkflowCFSPlanScheduler(plan=cfs_plan_scheduler_entity),
        cfs_plan_scheduler_entity,
    )


def _build_generator_args(trigger_data: TriggerData) -> dict[str, Any]:
    """Build args passed into WorkflowAppGenerator.generate for Celery executions."""

    args: dict[str, Any] = {
        "inputs": dict(trigger_data.inputs),
        "files": list(trigger_data.files),
        SKIP_PREPARE_USER_INPUTS_KEY: True,
    }
    return args


def _execute_workflow_common(
    task_data: WorkflowTaskData,
    cfs_plan_scheduler: AsyncWorkflowCFSPlanScheduler,
    cfs_plan_scheduler_entity: AsyncWorkflowCFSPlanEntity,
):
    """Execute workflow with common logic and trigger log updates."""

    # Create a new session for this task
    session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

    with session_factory() as session:
        trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)

        # Get trigger log
        trigger_log = trigger_log_repo.get_by_id(task_data.workflow_trigger_log_id)

        if not trigger_log:
            # This should not happen, but handle gracefully
            return

        # Reconstruct execution data from trigger log
        trigger_data = TriggerData.model_validate_json(trigger_log.trigger_data)

        # Update status to running
        trigger_log.status = WorkflowTriggerStatus.RUNNING
        trigger_log_repo.update(trigger_log)
        session.commit()

        start_time = datetime.now(UTC)

        try:
            # Get app and workflow models
            app_model = session.scalar(select(App).where(App.id == trigger_log.app_id))

            if not app_model:
                raise WorkflowNotFoundError(f"App not found: {trigger_log.app_id}")

            workflow = session.scalar(select(Workflow).where(Workflow.id == trigger_log.workflow_id))
            if not workflow:
                raise WorkflowNotFoundError(f"Workflow not found: {trigger_log.workflow_id}")

            user = _get_user(session, trigger_log)

            # Execute workflow using WorkflowAppGenerator
            generator = WorkflowAppGenerator()

            # Prepare args matching AppGenerateService.generate format
            args = _build_generator_args(trigger_data)

            # If workflow_id was specified, add it to args
            if trigger_data.workflow_id:
                args["workflow_id"] = str(trigger_data.workflow_id)

            # Execute the workflow with the trigger type
            generator.generate(
                app_model=app_model,
                workflow=workflow,
                user=user,
                args=args,
                invoke_from=InvokeFrom.SERVICE_API,
                streaming=False,
                call_depth=0,
                triggered_from=trigger_data.trigger_from,
                root_node_id=trigger_data.root_node_id,
                graph_engine_layers=[
                    # TODO: Re-enable TimeSliceLayer after the HITL release.
                    TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory),
                ],
            )

        except Exception as e:
            # Calculate elapsed time for failed execution
            elapsed_time = (datetime.now(UTC) - start_time).total_seconds()

            # Update trigger log with failure
            trigger_log.status = WorkflowTriggerStatus.FAILED
            trigger_log.error = str(e)
            trigger_log.finished_at = datetime.now(UTC)
            trigger_log.elapsed_time = elapsed_time
            trigger_log_repo.update(trigger_log)

            # Final failure - no retry logic (simplified like RAG tasks)
            session.commit()


def _get_user(session: Session, trigger_log: WorkflowTriggerLog) -> Account | EndUser:
    """Compose user from trigger log"""
    tenant = session.scalar(select(Tenant).where(Tenant.id == trigger_log.tenant_id))
    if not tenant:
        raise ValueError(f"Tenant not found: {trigger_log.tenant_id}")

    # Get user from trigger log
    if trigger_log.created_by_role == CreatorUserRole.ACCOUNT:
        user = session.scalar(select(Account).where(Account.id == trigger_log.created_by))
        if user:
            user.current_tenant = tenant
    else:  # CreatorUserRole.END_USER
        user = session.scalar(select(EndUser).where(EndUser.id == trigger_log.created_by))

    if not user:
        raise ValueError(f"User not found: {trigger_log.created_by} (role: {trigger_log.created_by_role})")

    return user
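Dispatch side for the tier-specific tasks above: the producer serializes a WorkflowTaskData and routes it to the task bound to the matching queue. A minimal sketch follows (hypothetical; the real tier selection and routing live outside this section and are assumed here):

# Hypothetical dispatch sketch -- not part of this commit.
from services.workflow.entities import WorkflowTaskData
from tasks.async_workflow_tasks import (
    execute_workflow_professional,
    execute_workflow_sandbox,
    execute_workflow_team,
)
from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue

# Assumed mapping from queue to task; choosing the tier/queue is out of scope here.
_QUEUE_TO_TASK = {
    AsyncWorkflowQueue.PROFESSIONAL_QUEUE: execute_workflow_professional,
    AsyncWorkflowQueue.TEAM_QUEUE: execute_workflow_team,
    AsyncWorkflowQueue.SANDBOX_QUEUE: execute_workflow_sandbox,
}


def dispatch_workflow(queue: AsyncWorkflowQueue, task_data: WorkflowTaskData) -> None:
    # Each task receives a plain dict and re-validates it with WorkflowTaskData.model_validate.
    _QUEUE_TO_TASK[queue].delay(task_data.model_dump(mode="json"))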
92
dify/api/tasks/batch_clean_document_task.py
Normal file
@@ -0,0 +1,92 @@
import logging
import time

import click
from celery import shared_task
from sqlalchemy import select

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment
from models.model import UploadFile

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]):
    """
    Clean document when document deleted.
    :param document_ids: document ids
    :param dataset_id: dataset id
    :param doc_form: doc_form
    :param file_ids: file ids

    Usage: batch_clean_document_task.delay(document_ids, dataset_id)
    """
    logger.info(click.style("Start batch clean documents when documents deleted", fg="green"))
    start_at = time.perf_counter()

    try:
        if not doc_form:
            raise ValueError("doc_form is required")
        dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()

        if not dataset:
            raise Exception("Document has no dataset")

        db.session.query(DatasetMetadataBinding).where(
            DatasetMetadataBinding.dataset_id == dataset_id,
            DatasetMetadataBinding.document_id.in_(document_ids),
        ).delete(synchronize_session=False)

        segments = db.session.scalars(
            select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
        ).all()
        # check segment is exist
        if segments:
            index_node_ids = [segment.index_node_id for segment in segments]
            index_processor = IndexProcessorFactory(doc_form).init_index_processor()
            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

            for segment in segments:
                image_upload_file_ids = get_image_upload_file_ids(segment.content)
                for upload_file_id in image_upload_file_ids:
                    image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
                    try:
                        if image_file and image_file.key:
                            storage.delete(image_file.key)
                    except Exception:
                        logger.exception(
                            "Delete image_files failed when storage deleted, \
                        image_upload_file_is: %s",
                            upload_file_id,
                        )
                    db.session.delete(image_file)
                db.session.delete(segment)

            db.session.commit()
        if file_ids:
            files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
            for file in files:
                try:
                    storage.delete(file.key)
                except Exception:
                    logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
                db.session.delete(file)

            db.session.commit()

        end_at = time.perf_counter()
        logger.info(
            click.style(
                f"Cleaned documents when documents deleted latency: {end_at - start_at}",
                fg="green",
            )
        )
    except Exception:
        logger.exception("Cleaned documents when documents deleted failed")
    finally:
        db.session.close()
151
dify/api/tasks/batch_create_segment_to_index_task.py
Normal file
@@ -0,0 +1,151 @@
import logging
import tempfile
import time
import uuid
from pathlib import Path

import click
import pandas as pd
from celery import shared_task
from sqlalchemy import func

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.model import UploadFile
from services.vector_service import VectorService

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def batch_create_segment_to_index_task(
    job_id: str,
    upload_file_id: str,
    dataset_id: str,
    document_id: str,
    tenant_id: str,
    user_id: str,
):
    """
    Async batch create segment to index
    :param job_id:
    :param upload_file_id:
    :param dataset_id:
    :param document_id:
    :param tenant_id:
    :param user_id:

    Usage: batch_create_segment_to_index_task.delay(job_id, upload_file_id, dataset_id, document_id, tenant_id, user_id)
    """
    logger.info(click.style(f"Start batch create segment jobId: {job_id}", fg="green"))
    start_at = time.perf_counter()

    indexing_cache_key = f"segment_batch_import_{job_id}"

    try:
        dataset = db.session.get(Dataset, dataset_id)
        if not dataset:
            raise ValueError("Dataset not exist.")

        dataset_document = db.session.get(Document, document_id)
        if not dataset_document:
            raise ValueError("Document not exist.")

        if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
            raise ValueError("Document is not available.")

        upload_file = db.session.get(UploadFile, upload_file_id)
        if not upload_file:
            raise ValueError("UploadFile not found.")

        with tempfile.TemporaryDirectory() as temp_dir:
            suffix = Path(upload_file.key).suffix
            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"  # type: ignore
            storage.download(upload_file.key, file_path)

            df = pd.read_csv(file_path)
            content = []
            for _, row in df.iterrows():
                if dataset_document.doc_form == "qa_model":
                    data = {"content": row.iloc[0], "answer": row.iloc[1]}
                else:
                    data = {"content": row.iloc[0]}
                content.append(data)
            if len(content) == 0:
                raise ValueError("The CSV file is empty.")

        document_segments = []
        embedding_model = None
        if dataset.indexing_technique == "high_quality":
            model_manager = ModelManager()
            embedding_model = model_manager.get_model_instance(
                tenant_id=dataset.tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model,
            )

        word_count_change = 0
        if embedding_model:
            tokens_list = embedding_model.get_text_embedding_num_tokens(
                texts=[segment["content"] for segment in content]
            )
        else:
            tokens_list = [0] * len(content)

        for segment, tokens in zip(content, tokens_list):
            content = segment["content"]
            doc_id = str(uuid.uuid4())
            segment_hash = helper.generate_text_hash(content)
            max_position = (
                db.session.query(func.max(DocumentSegment.position))
                .where(DocumentSegment.document_id == dataset_document.id)
                .scalar()
            )
            segment_document = DocumentSegment(
                tenant_id=tenant_id,
                dataset_id=dataset_id,
                document_id=document_id,
                index_node_id=doc_id,
                index_node_hash=segment_hash,
                position=max_position + 1 if max_position else 1,
                content=content,
                word_count=len(content),
                tokens=tokens,
                created_by=user_id,
                indexing_at=naive_utc_now(),
                status="completed",
                completed_at=naive_utc_now(),
            )
            if dataset_document.doc_form == "qa_model":
                segment_document.answer = segment["answer"]
                segment_document.word_count += len(segment["answer"])
            word_count_change += segment_document.word_count
            db.session.add(segment_document)
            document_segments.append(segment_document)

        assert dataset_document.word_count is not None
        dataset_document.word_count += word_count_change
        db.session.add(dataset_document)

        VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
        db.session.commit()
        redis_client.setex(indexing_cache_key, 600, "completed")
        end_at = time.perf_counter()
        logger.info(
            click.style(
                f"Segment batch created job: {job_id} latency: {end_at - start_at}",
                fg="green",
            )
        )
    except Exception:
        logger.exception("Segments batch created index failed")
        redis_client.setex(indexing_cache_key, 600, "error")
    finally:
        db.session.close()
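The CSV consumed above is read with pandas defaults, so its first row is treated as column names and skipped; the remaining rows are read positionally (row.iloc[0] is the segment content and, for "qa_model" documents, row.iloc[1] is the answer). A minimal sketch of a matching input file (illustrative only, not part of this commit):

# Hypothetical input-file sketch -- not part of this commit.
import pandas as pd

# Header names are arbitrary; only column positions matter to the task.
pd.DataFrame(
    {
        "content": ["What is a dataset?", "What is a segment?"],
        "answer": ["A collection of documents used for retrieval.", "A chunk of a document stored in the index."],
    }
).to_csv("segments.csv", index=False)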
154
dify/api/tasks/clean_dataset_task.py
Normal file
@@ -0,0 +1,154 @@
import logging
import time

import click
from celery import shared_task
from sqlalchemy import select

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import (
    AppDatasetJoin,
    Dataset,
    DatasetMetadata,
    DatasetMetadataBinding,
    DatasetProcessRule,
    DatasetQuery,
    Document,
    DocumentSegment,
)
from models.model import UploadFile

logger = logging.getLogger(__name__)


# Add import statement for ValueError
@shared_task(queue="dataset")
def clean_dataset_task(
    dataset_id: str,
    tenant_id: str,
    indexing_technique: str,
    index_struct: str,
    collection_binding_id: str,
    doc_form: str,
):
    """
    Clean dataset when dataset deleted.
    :param dataset_id: dataset id
    :param tenant_id: tenant id
    :param indexing_technique: indexing technique
    :param index_struct: index struct dict
    :param collection_binding_id: collection binding id
    :param doc_form: dataset form

    Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct)
    """
    logger.info(click.style(f"Start clean dataset when dataset deleted: {dataset_id}", fg="green"))
    start_at = time.perf_counter()

    try:
        dataset = Dataset(
            id=dataset_id,
            tenant_id=tenant_id,
            indexing_technique=indexing_technique,
            index_struct=index_struct,
            collection_binding_id=collection_binding_id,
        )
        documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
        segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()

        # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
        # This ensures all invalid doc_form values are properly handled
        if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
            # Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
            from core.rag.index_processor.constant.index_type import IndexType

            doc_form = IndexType.PARAGRAPH_INDEX
            logger.info(
                click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow")
            )

        # Add exception handling around IndexProcessorFactory.clean() to prevent single point of failure
        # This ensures Document/Segment deletion can continue even if vector database cleanup fails
        try:
            index_processor = IndexProcessorFactory(doc_form).init_index_processor()
            index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
            logger.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green"))
        except Exception:
            logger.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red"))
            # Continue with document and segment deletion even if vector cleanup fails
            logger.info(
                click.style(f"Continuing with document and segment deletion for dataset: {dataset_id}", fg="yellow")
            )

        if documents is None or len(documents) == 0:
            logger.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green"))
        else:
            logger.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green"))

            for document in documents:
                db.session.delete(document)

            for segment in segments:
                image_upload_file_ids = get_image_upload_file_ids(segment.content)
                for upload_file_id in image_upload_file_ids:
                    image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
                    if image_file is None:
                        continue
                    try:
                        storage.delete(image_file.key)
                    except Exception:
                        logger.exception(
                            "Delete image_files failed when storage deleted, \
                        image_upload_file_is: %s",
                            upload_file_id,
                        )
                    db.session.delete(image_file)
                db.session.delete(segment)

        db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
        db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
        db.session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete()
        # delete dataset metadata
        db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
        db.session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
        # delete files
        if documents:
            for document in documents:
                try:
                    if document.data_source_type == "upload_file":
                        if document.data_source_info:
                            data_source_info = document.data_source_info_dict
                            if data_source_info and "upload_file_id" in data_source_info:
                                file_id = data_source_info["upload_file_id"]
                                file = (
                                    db.session.query(UploadFile)
                                    .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
                                    .first()
                                )
                                if not file:
                                    continue
                                storage.delete(file.key)
                                db.session.delete(file)
                except Exception:
                    continue

        db.session.commit()
        end_at = time.perf_counter()
        logger.info(
            click.style(f"Cleaned dataset when dataset deleted: {dataset_id} latency: {end_at - start_at}", fg="green")
        )
    except Exception:
        # Add rollback to prevent dirty session state in case of exceptions
        # This ensures the database session is properly cleaned up
        try:
            db.session.rollback()
            logger.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow"))
        except Exception:
            logger.exception("Failed to rollback database session")

        logger.exception("Cleaned dataset when dataset deleted failed")
    finally:
        db.session.close()
90
dify/api/tasks/clean_document_task.py
Normal file
@@ -0,0 +1,90 @@
import logging
import time

import click
from celery import shared_task
from sqlalchemy import select

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment
from models.model import UploadFile

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_id: str | None):
    """
    Clean document when document deleted.
    :param document_id: document id
    :param dataset_id: dataset id
    :param doc_form: doc_form
    :param file_id: file id

    Usage: clean_document_task.delay(document_id, dataset_id)
    """
    logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
    start_at = time.perf_counter()

    try:
        dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()

        if not dataset:
            raise Exception("Document has no dataset")

        segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
        # check segment is exist
        if segments:
            index_node_ids = [segment.index_node_id for segment in segments]
            index_processor = IndexProcessorFactory(doc_form).init_index_processor()
            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

            for segment in segments:
                image_upload_file_ids = get_image_upload_file_ids(segment.content)
                for upload_file_id in image_upload_file_ids:
                    image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
                    if image_file is None:
                        continue
                    try:
                        storage.delete(image_file.key)
                    except Exception:
                        logger.exception(
                            "Delete image_files failed when storage deleted, \
                        image_upload_file_is: %s",
                            upload_file_id,
                        )
                    db.session.delete(image_file)
                db.session.delete(segment)

            db.session.commit()
        if file_id:
            file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
            if file:
                try:
                    storage.delete(file.key)
                except Exception:
                    logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
                db.session.delete(file)
            db.session.commit()

        # delete dataset metadata binding
        db.session.query(DatasetMetadataBinding).where(
            DatasetMetadataBinding.dataset_id == dataset_id,
            DatasetMetadataBinding.document_id == document_id,
        ).delete()
        db.session.commit()

        end_at = time.perf_counter()
        logger.info(
            click.style(
                f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
                fg="green",
            )
        )
    except Exception:
        logger.exception("Cleaned document when document deleted failed")
    finally:
        db.session.close()
60
dify/api/tasks/clean_notion_document_task.py
Normal file
@@ -0,0 +1,60 @@
import logging
import time

import click
from celery import shared_task
from sqlalchemy import select

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def clean_notion_document_task(document_ids: list[str], dataset_id: str):
    """
    Clean document when document deleted.
    :param document_ids: document ids
    :param dataset_id: dataset id

    Usage: clean_notion_document_task.delay(document_ids, dataset_id)
    """
    logger.info(click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green"))
    start_at = time.perf_counter()

    try:
        dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()

        if not dataset:
            raise Exception("Document has no dataset")
        index_type = dataset.doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        for document_id in document_ids:
            document = db.session.query(Document).where(Document.id == document_id).first()
            db.session.delete(document)

            segments = db.session.scalars(
                select(DocumentSegment).where(DocumentSegment.document_id == document_id)
            ).all()
            index_node_ids = [segment.index_node_id for segment in segments]

            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

            for segment in segments:
                db.session.delete(segment)
        db.session.commit()
        end_at = time.perf_counter()
        logger.info(
            click.style(
                "Clean document when import form notion document deleted end :: {} latency: {}".format(
                    dataset_id, end_at - start_at
                ),
                fg="green",
            )
        )
    except Exception:
        logger.exception("Cleaned document when import form notion document deleted failed")
    finally:
        db.session.close()
99
dify/api/tasks/create_segment_to_index_task.py
Normal file
@@ -0,0 +1,99 @@
import logging
import time

import click
from celery import shared_task

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = None):
    """
    Async create segment to index
    :param segment_id:
    :param keywords:
    Usage: create_segment_to_index_task.delay(segment_id)
    """
    logger.info(click.style(f"Start create segment to index: {segment_id}", fg="green"))
    start_at = time.perf_counter()

    segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
    if not segment:
        logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
        db.session.close()
        return

    if segment.status != "waiting":
        db.session.close()
        return

    indexing_cache_key = f"segment_{segment.id}_indexing"

    try:
        # update segment status to indexing
        db.session.query(DocumentSegment).filter_by(id=segment.id).update(
            {
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: naive_utc_now(),
            }
        )
        db.session.commit()
        document = Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            },
        )

        dataset = segment.dataset

        if not dataset:
            logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
            return

        dataset_document = segment.document

        if not dataset_document:
            logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
            return

        if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
            logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
            return

        index_type = dataset.doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        index_processor.load(dataset, [document])

        # update segment to completed
        db.session.query(DocumentSegment).filter_by(id=segment.id).update(
            {
                DocumentSegment.status: "completed",
                DocumentSegment.completed_at: naive_utc_now(),
            }
        )
        db.session.commit()

        end_at = time.perf_counter()
        logger.info(click.style(f"Segment created to index: {segment.id} latency: {end_at - start_at}", fg="green"))
    except Exception as e:
        logger.exception("create segment to index failed")
        segment.enabled = False
        segment.disabled_at = naive_utc_now()
        segment.status = "error"
        segment.error = str(e)
        db.session.commit()
    finally:
        redis_client.delete(indexing_cache_key)
        db.session.close()
171
dify/api/tasks/deal_dataset_index_update_task.py
Normal file
@@ -0,0 +1,171 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task # type: ignore
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.rag.models.document import ChildDocument, Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def deal_dataset_index_update_task(dataset_id: str, action: str):
|
||||
"""
|
||||
    Async handle dataset index update
|
||||
:param dataset_id: dataset_id
|
||||
    :param action: action to perform ("upgrade" or "update")
|
||||
Usage: deal_dataset_index_update_task.delay(dataset_id, action)
|
||||
"""
|
||||
    logging.info(click.style(f"Start dealing with dataset index update: {dataset_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
if action == "upgrade":
|
||||
dataset_documents = (
|
||||
db.session.query(DatasetDocument)
|
||||
.where(
|
||||
DatasetDocument.dataset_id == dataset_id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if dataset_documents:
|
||||
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
|
||||
{"indexing_status": "indexing"}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
|
||||
for dataset_document in dataset_documents:
|
||||
try:
|
||||
# add from vector index
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
if segments:
|
||||
documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
|
||||
documents.append(document)
|
||||
# save vector index
|
||||
# clean keywords
|
||||
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
|
||||
index_processor.load(dataset, documents, with_keywords=False)
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "completed"}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
elif action == "update":
|
||||
dataset_documents = (
|
||||
db.session.query(DatasetDocument)
|
||||
.where(
|
||||
DatasetDocument.dataset_id == dataset_id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
# add new index
|
||||
if dataset_documents:
|
||||
# update document status
|
||||
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
|
||||
{"indexing_status": "indexing"}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
|
||||
# clean index
|
||||
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
|
||||
|
||||
for dataset_document in dataset_documents:
|
||||
# update from vector index
|
||||
try:
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
if segments:
|
||||
documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
child_chunks = segment.get_child_chunks()
|
||||
if child_chunks:
|
||||
child_documents = []
|
||||
for child_chunk in child_chunks:
|
||||
child_document = ChildDocument(
|
||||
page_content=child_chunk.content,
|
||||
metadata={
|
||||
"doc_id": child_chunk.index_node_id,
|
||||
"doc_hash": child_chunk.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
child_documents.append(child_document)
|
||||
document.children = child_documents
|
||||
documents.append(document)
|
||||
# save vector index
|
||||
index_processor.load(dataset, documents, with_keywords=False)
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "completed"}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
else:
|
||||
# clean collection
|
||||
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")
|
||||
)
|
||||
except Exception:
|
||||
logging.exception("Deal dataset vector index failed")
|
||||
finally:
|
||||
db.session.close()
|
||||
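A short, hypothetical dispatch sketch for the task above; the dataset id is a placeholder, and only the two action values handled by the branches ("upgrade" and "update") are shown.

from tasks.deal_dataset_index_update_task import deal_dataset_index_update_task

# Rebuild the index after the dataset's index structure is upgraded,
# or refresh it after the dataset's settings change.
deal_dataset_index_update_task.delay("dataset-id-placeholder", "upgrade")
deal_dataset_index_update_task.delay("dataset-id-placeholder", "update")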
169
dify/api/tasks/deal_dataset_vector_index_task.py
Normal file
@@ -0,0 +1,169 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Literal
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.rag.models.document import ChildDocument, Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "add", "update"]):
|
||||
"""
|
||||
    Async handle dataset vector index changes
|
||||
:param dataset_id: dataset_id
|
||||
    :param action: action to perform ("remove", "add" or "update")
|
||||
Usage: deal_dataset_vector_index_task.delay(dataset_id, action)
|
||||
"""
|
||||
logger.info(click.style(f"Start deal dataset vector index: {dataset_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
if action == "remove":
|
||||
index_processor.clean(dataset, None, with_keywords=False)
|
||||
elif action == "add":
|
||||
dataset_documents = db.session.scalars(
|
||||
select(DatasetDocument).where(
|
||||
DatasetDocument.dataset_id == dataset_id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
).all()
|
||||
|
||||
if dataset_documents:
|
||||
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
|
||||
{"indexing_status": "indexing"}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
|
||||
for dataset_document in dataset_documents:
|
||||
try:
|
||||
# add from vector index
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
if segments:
|
||||
documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
|
||||
documents.append(document)
|
||||
# save vector index
|
||||
index_processor.load(dataset, documents, with_keywords=False)
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "completed"}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
elif action == "update":
|
||||
dataset_documents = db.session.scalars(
|
||||
select(DatasetDocument).where(
|
||||
DatasetDocument.dataset_id == dataset_id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
).all()
|
||||
# add new index
|
||||
if dataset_documents:
|
||||
# update document status
|
||||
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
|
||||
{"indexing_status": "indexing"}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
|
||||
# clean index
|
||||
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
|
||||
|
||||
for dataset_document in dataset_documents:
|
||||
# update from vector index
|
||||
try:
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
if segments:
|
||||
documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
child_chunks = segment.get_child_chunks()
|
||||
if child_chunks:
|
||||
child_documents = []
|
||||
for child_chunk in child_chunks:
|
||||
child_document = ChildDocument(
|
||||
page_content=child_chunk.content,
|
||||
metadata={
|
||||
"doc_id": child_chunk.index_node_id,
|
||||
"doc_hash": child_chunk.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
child_documents.append(child_document)
|
||||
document.children = child_documents
|
||||
documents.append(document)
|
||||
# save vector index
|
||||
index_processor.load(dataset, documents, with_keywords=False)
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "completed"}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
else:
|
||||
# clean collection
|
||||
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Deal dataset vector index: {dataset_id} latency: {end_at - start_at}", fg="green"))
|
||||
except Exception:
|
||||
logger.exception("Deal dataset vector index failed")
|
||||
finally:
|
||||
db.session.close()
|
||||
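A hypothetical call-site sketch for the task above; the dataset id is a placeholder. The `action` argument is typed as Literal["remove", "add", "update"]: "remove" drops the collection, while "add" and "update" rebuild it from the enabled, completed documents.

from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task

deal_dataset_vector_index_task.delay("dataset-id-placeholder", "add")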
26
dify/api/tasks/delete_account_task.py
Normal file
@@ -0,0 +1,26 @@
import logging

from celery import shared_task

from extensions.ext_database import db
from models import Account
from services.billing_service import BillingService
from tasks.mail_account_deletion_task import send_deletion_success_task

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def delete_account_task(account_id):
    account = db.session.query(Account).where(Account.id == account_id).first()
    try:
        BillingService.delete_account(account_id)
    except Exception:
        logger.exception("Failed to delete account %s from billing service.", account_id)
        raise

    if not account:
        logger.error("Account %s not found.", account_id)
        return
    # send success email
    send_deletion_success_task.delay(account.email)
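A hypothetical call-site sketch for the task above; the account id is a placeholder, and where exactly the application soft-deletes the account row first is an assumption.

from tasks.delete_account_task import delete_account_task

# Queue the billing cleanup and the success email after the account
# has been removed in the application layer.
delete_account_task.delay("account-id-placeholder")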
70
dify/api/tasks/delete_conversation_task.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import ConversationVariable
|
||||
from models.model import Message, MessageAnnotation, MessageFeedback
|
||||
from models.tools import ToolConversationVariables, ToolFile
|
||||
from models.web import PinnedConversation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="conversation")
|
||||
def delete_conversation_related_data(conversation_id: str):
|
||||
"""
|
||||
    Delete conversation-related data in the correct order from the database to respect foreign key constraints
|
||||
|
||||
Args:
|
||||
conversation_id: conversation Id
|
||||
"""
|
||||
|
||||
logger.info(
|
||||
click.style(f"Starting to delete conversation data from db for conversation_id {conversation_id}", fg="green")
|
||||
)
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
db.session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
db.session.query(ToolConversationVariables).where(
|
||||
ToolConversationVariables.conversation_id == conversation_id
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
db.session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False)
|
||||
|
||||
db.session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
db.session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False)
|
||||
|
||||
db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Succeeded cleaning data from db for conversation_id {conversation_id} latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to delete data from db for conversation_id: %s failed", conversation_id)
|
||||
db.session.rollback()
|
||||
raise e
|
||||
finally:
|
||||
db.session.close()
|
||||
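A hypothetical call-site sketch for the task above; the conversation id is a placeholder. It assumes the Conversation row itself is removed by the caller, with this task cleaning up the dependent rows (annotations, feedback, tool data, variables, messages, pins).

from tasks.delete_conversation_task import delete_conversation_related_data

delete_conversation_related_data.delay("conversation-id-placeholder")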
58
dify/api/tasks/delete_segment_from_index_task.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def delete_segment_from_index_task(
|
||||
index_node_ids: list, dataset_id: str, document_id: str, child_node_ids: list | None = None
|
||||
):
|
||||
"""
|
||||
Async Remove segment from index
|
||||
:param index_node_ids:
|
||||
:param dataset_id:
|
||||
    :param document_id:
    :param child_node_ids: optional pre-computed child chunk node ids
|
||||
|
||||
Usage: delete_segment_from_index_task.delay(index_node_ids, dataset_id, document_id)
|
||||
"""
|
||||
logger.info(click.style("Start delete segment from index", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
try:
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
            logger.warning("Dataset %s not found, skipping index cleanup", dataset_id)
|
||||
return
|
||||
|
||||
dataset_document = db.session.query(Document).where(Document.id == document_id).first()
|
||||
if not dataset_document:
|
||||
return
|
||||
|
||||
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
|
||||
logging.info("Document not in valid state for index operations, skipping")
|
||||
return
|
||||
doc_form = dataset_document.doc_form
|
||||
|
||||
# Proceed with index cleanup using the index_node_ids directly
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
index_processor.clean(
|
||||
dataset,
|
||||
index_node_ids,
|
||||
with_keywords=True,
|
||||
delete_child_chunks=True,
|
||||
precomputed_child_node_ids=child_node_ids,
|
||||
)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))
|
||||
except Exception:
|
||||
logger.exception("delete segment from index failed")
|
||||
finally:
|
||||
db.session.close()
|
||||
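A hypothetical dispatch sketch for the task above with placeholder ids; the keyword-argument style is an assumption, the parameter names come from the task signature. Passing `child_node_ids` lets the processor skip looking the child chunk node ids up again.

from tasks.delete_segment_from_index_task import delete_segment_from_index_task

delete_segment_from_index_task.delay(
    index_node_ids=["node-id-placeholder"],
    dataset_id="dataset-id-placeholder",
    document_id="document-id-placeholder",
    child_node_ids=["child-node-id-placeholder"],
)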
68
dify/api/tasks/disable_segment_from_index_task.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import DocumentSegment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def disable_segment_from_index_task(segment_id: str):
|
||||
"""
|
||||
Async disable segment from index
|
||||
:param segment_id:
|
||||
|
||||
Usage: disable_segment_from_index_task.delay(segment_id)
|
||||
"""
|
||||
logger.info(click.style(f"Start disable segment from index: {segment_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
|
||||
if not segment:
|
||||
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
if segment.status != "completed":
|
||||
logger.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
|
||||
try:
|
||||
dataset = segment.dataset
|
||||
|
||||
if not dataset:
|
||||
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
|
||||
return
|
||||
|
||||
dataset_document = segment.document
|
||||
|
||||
if not dataset_document:
|
||||
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
|
||||
return
|
||||
|
||||
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
|
||||
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
|
||||
return
|
||||
|
||||
index_type = dataset_document.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
index_processor.clean(dataset, [segment.index_node_id])
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Segment removed from index: {segment.id} latency: {end_at - start_at}", fg="green"))
|
||||
except Exception:
|
||||
logger.exception("remove segment from index failed")
|
||||
segment.enabled = True
|
||||
db.session.commit()
|
||||
finally:
|
||||
redis_client.delete(indexing_cache_key)
|
||||
db.session.close()
|
||||
84
dify/api/tasks/disable_segments_from_index_task.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def disable_segments_from_index_task(segment_ids: list, dataset_id: str, document_id: str):
|
||||
"""
|
||||
Async disable segments from index
|
||||
:param segment_ids: list of segment ids
|
||||
:param dataset_id: dataset id
|
||||
:param document_id: document id
|
||||
|
||||
Usage: disable_segments_from_index_task.delay(segment_ids, dataset_id, document_id)
|
||||
"""
|
||||
start_at = time.perf_counter()
|
||||
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
|
||||
|
||||
if not dataset_document:
|
||||
logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
|
||||
db.session.close()
|
||||
return
|
||||
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
|
||||
logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
|
||||
db.session.close()
|
||||
return
|
||||
# sync index processor
|
||||
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
|
||||
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.document_id == document_id,
|
||||
)
|
||||
).all()
|
||||
|
||||
if not segments:
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
try:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green"))
|
||||
except Exception:
|
||||
# update segment error msg
|
||||
db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.document_id == document_id,
|
||||
).update(
|
||||
{
|
||||
"disabled_at": None,
|
||||
"disabled_by": None,
|
||||
"enabled": True,
|
||||
}
|
||||
)
|
||||
db.session.commit()
|
||||
finally:
|
||||
for segment in segments:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
redis_client.delete(indexing_cache_key)
|
||||
db.session.close()
|
||||
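A hypothetical dispatch sketch for the batch task above with placeholder ids; the parameter names come from the task signature. All segment ids must belong to the given dataset and document, otherwise the query in the task filters them out and it returns early.

from tasks.disable_segments_from_index_task import disable_segments_from_index_task

disable_segments_from_index_task.delay(
    segment_ids=["segment-id-placeholder"],
    dataset_id="dataset-id-placeholder",
    document_id="document-id-placeholder",
)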
124
dify/api/tasks/document_indexing_sync_task.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.source import DataSourceOauthBinding
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
"""
|
||||
    Async sync document from its data source and re-index it
|
||||
:param dataset_id:
|
||||
:param document_id:
|
||||
|
||||
Usage: document_indexing_sync_task.delay(dataset_id, document_id)
|
||||
"""
|
||||
logger.info(click.style(f"Start sync document: {document_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
|
||||
if not document:
|
||||
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
data_source_info = document.data_source_info_dict
|
||||
if document.data_source_type == "notion_import":
|
||||
if (
|
||||
not data_source_info
|
||||
or "notion_page_id" not in data_source_info
|
||||
or "notion_workspace_id" not in data_source_info
|
||||
):
|
||||
raise ValueError("no notion page found")
|
||||
workspace_id = data_source_info["notion_workspace_id"]
|
||||
page_id = data_source_info["notion_page_id"]
|
||||
page_type = data_source_info["type"]
|
||||
page_edited_time = data_source_info["last_edited_time"]
|
||||
|
||||
data_source_binding = (
|
||||
db.session.query(DataSourceOauthBinding)
|
||||
.where(
|
||||
sa.and_(
|
||||
DataSourceOauthBinding.tenant_id == document.tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not data_source_binding:
|
||||
raise ValueError("Data source binding not found.")
|
||||
|
||||
loader = NotionExtractor(
|
||||
notion_workspace_id=workspace_id,
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=data_source_binding.access_token,
|
||||
tenant_id=document.tenant_id,
|
||||
)
|
||||
|
||||
last_edited_time = loader.get_notion_last_edited_time()
|
||||
|
||||
# check the page is updated
|
||||
if last_edited_time != page_edited_time:
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
# delete all document segment and index
|
||||
try:
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
index_type = document.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
).all()
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
||||
for segment in segments:
|
||||
db.session.delete(segment)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
"Cleaned document when document update data source or process rule: {} latency: {}".format(
|
||||
document_id, end_at - start_at
|
||||
),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Cleaned document when document update data source or process rule failed")
|
||||
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run([document])
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
|
||||
except DocumentIsPausedError as ex:
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
logger.exception("document_indexing_sync_task failed, document_id: %s", document_id)
|
||||
finally:
|
||||
db.session.close()
|
||||
168
dify/api/tasks/document_indexing_task.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.document_task import DocumentTask
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Dataset, Document
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def document_indexing_task(dataset_id: str, document_ids: list):
|
||||
"""
|
||||
Async process document
|
||||
:param dataset_id:
|
||||
:param document_ids:
|
||||
|
||||
.. warning:: TO BE DEPRECATED
|
||||
This function will be deprecated and removed in a future version.
|
||||
Use normal_document_indexing_task or priority_document_indexing_task instead.
|
||||
|
||||
Usage: document_indexing_task.delay(dataset_id, document_ids)
|
||||
"""
|
||||
logger.warning("document indexing legacy mode received: %s - %s", dataset_id, document_ids)
|
||||
_document_indexing(dataset_id, document_ids)
|
||||
|
||||
|
||||
def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
"""
|
||||
Process document for tasks
|
||||
:param dataset_id:
|
||||
:param document_ids:
|
||||
|
||||
Usage: _document_indexing(dataset_id, document_ids)
|
||||
"""
|
||||
documents = []
|
||||
start_at = time.perf_counter()
|
||||
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow"))
|
||||
db.session.close()
|
||||
return
|
||||
# check document limit
|
||||
features = FeatureService.get_features(dataset.tenant_id)
|
||||
try:
|
||||
if features.billing.enabled:
|
||||
vector_space = features.vector_space
|
||||
count = len(document_ids)
|
||||
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
|
||||
if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
|
||||
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
|
||||
if count > batch_upload_limit:
|
||||
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
|
||||
if 0 < vector_space.limit <= vector_space.size:
|
||||
raise ValueError(
|
||||
"Your total number of documents plus the number of uploads have over the limit of "
|
||||
"your subscription."
|
||||
)
|
||||
except Exception as e:
|
||||
for document_id in document_ids:
|
||||
document = (
|
||||
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
for document_id in document_ids:
|
||||
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
|
||||
|
||||
document = (
|
||||
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
)
|
||||
|
||||
if document:
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
documents.append(document)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run(documents)
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
|
||||
except DocumentIsPausedError as ex:
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
|
||||
def _document_indexing_with_tenant_queue(
|
||||
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None]
|
||||
):
|
||||
try:
|
||||
_document_indexing(dataset_id, document_ids)
|
||||
except Exception:
|
||||
logger.exception("Error processing document indexing %s for tenant %s: %s", dataset_id, tenant_id)
|
||||
finally:
|
||||
tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
|
||||
|
||||
# Check if there are waiting tasks in the queue
|
||||
# Use rpop to get the next task from the queue (FIFO order)
|
||||
next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
|
||||
|
||||
logger.info("document indexing tenant isolation queue next tasks: %s", next_tasks)
|
||||
|
||||
if next_tasks:
|
||||
for next_task in next_tasks:
|
||||
document_task = DocumentTask(**next_task)
|
||||
# Process the next waiting task
|
||||
# Keep the flag set to indicate a task is running
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
task_func.delay( # type: ignore
|
||||
tenant_id=document_task.tenant_id,
|
||||
dataset_id=document_task.dataset_id,
|
||||
document_ids=document_task.document_ids,
|
||||
)
|
||||
else:
|
||||
# No more waiting tasks, clear the flag
|
||||
tenant_isolated_task_queue.delete_task_key()
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def normal_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
|
||||
"""
|
||||
Async process document
|
||||
:param tenant_id:
|
||||
:param dataset_id:
|
||||
:param document_ids:
|
||||
|
||||
Usage: normal_document_indexing_task.delay(tenant_id, dataset_id, document_ids)
|
||||
"""
|
||||
logger.info("normal document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids)
|
||||
_document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, normal_document_indexing_task)
|
||||
|
||||
|
||||
@shared_task(queue="priority_dataset")
|
||||
def priority_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
|
||||
"""
|
||||
Priority async process document
|
||||
:param tenant_id:
|
||||
:param dataset_id:
|
||||
:param document_ids:
|
||||
|
||||
Usage: priority_document_indexing_task.delay(tenant_id, dataset_id, document_ids)
|
||||
"""
|
||||
logger.info("priority document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids)
|
||||
_document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, priority_document_indexing_task)
|
||||
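A hedged producer-side sketch for the two tenant-isolated entry points above. The plan-based routing rule and the `dispatch_document_indexing` helper are assumptions; only the `.delay(tenant_id, dataset_id, document_ids)` signatures are taken from this file.

from tasks.document_indexing_task import (
    normal_document_indexing_task,
    priority_document_indexing_task,
)


def dispatch_document_indexing(tenant_id: str, dataset_id: str, document_ids: list[str], is_priority: bool) -> None:
    # Assumed routing rule: priority tenants go to the "priority_dataset"
    # queue, everyone else to the regular "dataset" queue.
    task = priority_document_indexing_task if is_priority else normal_document_indexing_task
    task.delay(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids)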
81
dify/api/tasks/document_indexing_update_task.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def document_indexing_update_task(dataset_id: str, document_id: str):
|
||||
"""
|
||||
Async update document
|
||||
:param dataset_id:
|
||||
:param document_id:
|
||||
|
||||
Usage: document_indexing_update_task.delay(dataset_id, document_id)
|
||||
"""
|
||||
logger.info(click.style(f"Start update document: {document_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
|
||||
if not document:
|
||||
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
# delete all document segment and index
|
||||
try:
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
|
||||
index_type = document.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
|
||||
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
||||
for segment in segments:
|
||||
db.session.delete(segment)
|
||||
db.session.commit()
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
"Cleaned document when document update data source or process rule: {} latency: {}".format(
|
||||
document_id, end_at - start_at
|
||||
),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Cleaned document when document update data source or process rule failed")
|
||||
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run([document])
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
|
||||
except DocumentIsPausedError as ex:
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
logger.exception("document_indexing_update_task failed, document_id: %s", document_id)
|
||||
finally:
|
||||
db.session.close()
|
||||
112
dify/api/tasks/duplicate_document_indexing_task.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
|
||||
"""
|
||||
Async process document
|
||||
:param dataset_id:
|
||||
:param document_ids:
|
||||
|
||||
Usage: duplicate_document_indexing_task.delay(dataset_id, document_ids)
|
||||
"""
|
||||
documents = []
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if dataset is None:
|
||||
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
# check document limit
|
||||
features = FeatureService.get_features(dataset.tenant_id)
|
||||
try:
|
||||
if features.billing.enabled:
|
||||
vector_space = features.vector_space
|
||||
count = len(document_ids)
|
||||
if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
|
||||
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
|
||||
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
|
||||
if count > batch_upload_limit:
|
||||
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
|
||||
current = int(getattr(vector_space, "size", 0) or 0)
|
||||
limit = int(getattr(vector_space, "limit", 0) or 0)
|
||||
if limit > 0 and (current + count) > limit:
|
||||
raise ValueError(
|
||||
"Your total number of documents plus the number of uploads have exceeded the limit of "
|
||||
"your subscription."
|
||||
)
|
||||
except Exception as e:
|
||||
for document_id in document_ids:
|
||||
document = (
|
||||
db.session.query(Document)
|
||||
.where(Document.id == document_id, Document.dataset_id == dataset_id)
|
||||
.first()
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
return
|
||||
|
||||
for document_id in document_ids:
|
||||
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
|
||||
|
||||
document = (
|
||||
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
)
|
||||
|
||||
if document:
|
||||
# clean old data
|
||||
index_type = document.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
||||
for segment in segments:
|
||||
db.session.delete(segment)
|
||||
db.session.commit()
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
documents.append(document)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run(documents)
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
|
||||
except DocumentIsPausedError as ex:
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id)
|
||||
finally:
|
||||
db.session.close()
|
||||
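The vector-space guard in the task above reduces to a small predicate. This standalone restatement (names are illustrative, not from the codebase) makes the boundary condition explicit: a non-positive limit means "unlimited", otherwise the existing documents plus the new uploads must stay within the subscription limit.

def exceeds_vector_space(current_size: int, upload_count: int, limit: int) -> bool:
    # Mirrors: limit > 0 and (current + count) > limit
    return limit > 0 and (current_size + upload_count) > limit


assert exceeds_vector_space(current_size=98, upload_count=3, limit=100) is True
assert exceeds_vector_space(current_size=98, upload_count=2, limit=100) is False
assert exceeds_vector_space(current_size=10_000, upload_count=1, limit=0) is False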
100
dify/api/tasks/enable_segment_to_index_task.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.rag.models.document import ChildDocument, Document
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import DocumentSegment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def enable_segment_to_index_task(segment_id: str):
|
||||
"""
|
||||
Async enable segment to index
|
||||
:param segment_id:
|
||||
|
||||
Usage: enable_segment_to_index_task.delay(segment_id)
|
||||
"""
|
||||
logger.info(click.style(f"Start enable segment to index: {segment_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
|
||||
if not segment:
|
||||
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
if segment.status != "completed":
|
||||
logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
|
||||
try:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
|
||||
dataset = segment.dataset
|
||||
|
||||
if not dataset:
|
||||
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
|
||||
return
|
||||
|
||||
dataset_document = segment.document
|
||||
|
||||
if not dataset_document:
|
||||
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
|
||||
return
|
||||
|
||||
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
|
||||
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
|
||||
return
|
||||
|
||||
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
child_chunks = segment.get_child_chunks()
|
||||
if child_chunks:
|
||||
child_documents = []
|
||||
for child_chunk in child_chunks:
|
||||
child_document = ChildDocument(
|
||||
page_content=child_chunk.content,
|
||||
metadata={
|
||||
"doc_id": child_chunk.index_node_id,
|
||||
"doc_hash": child_chunk.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
child_documents.append(child_document)
|
||||
document.children = child_documents
|
||||
# save vector index
|
||||
index_processor.load(dataset, [document])
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
|
||||
except Exception as e:
|
||||
logger.exception("enable segment to index failed")
|
||||
segment.enabled = False
|
||||
segment.disabled_at = naive_utc_now()
|
||||
segment.status = "error"
|
||||
segment.error = str(e)
|
||||
db.session.commit()
|
||||
finally:
|
||||
redis_client.delete(indexing_cache_key)
|
||||
db.session.close()
|
||||
116
dify/api/tasks/enable_segments_to_index_task.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.rag.models.document import ChildDocument, Document
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_id: str):
|
||||
"""
|
||||
Async enable segments to index
|
||||
:param segment_ids: list of segment ids
|
||||
:param dataset_id: dataset id
|
||||
:param document_id: document id
|
||||
|
||||
Usage: enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id)
|
||||
"""
|
||||
start_at = time.perf_counter()
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
|
||||
return
|
||||
|
||||
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
|
||||
|
||||
if not dataset_document:
|
||||
logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
|
||||
db.session.close()
|
||||
return
|
||||
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
|
||||
logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
|
||||
db.session.close()
|
||||
return
|
||||
# sync index processor
|
||||
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
|
||||
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.document_id == document_id,
|
||||
)
|
||||
).all()
|
||||
if not segments:
|
||||
logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
try:
|
||||
documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": document_id,
|
||||
"dataset_id": dataset_id,
|
||||
},
|
||||
)
|
||||
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
child_chunks = segment.get_child_chunks()
|
||||
if child_chunks:
|
||||
child_documents = []
|
||||
for child_chunk in child_chunks:
|
||||
child_document = ChildDocument(
|
||||
page_content=child_chunk.content,
|
||||
metadata={
|
||||
"doc_id": child_chunk.index_node_id,
|
||||
"doc_hash": child_chunk.index_node_hash,
|
||||
"document_id": document_id,
|
||||
"dataset_id": dataset_id,
|
||||
},
|
||||
)
|
||||
child_documents.append(child_document)
|
||||
document.children = child_documents
|
||||
documents.append(document)
|
||||
# save vector index
|
||||
index_processor.load(dataset, documents)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
|
||||
except Exception as e:
|
||||
logger.exception("enable segments to index failed")
|
||||
# update segment error msg
|
||||
db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.document_id == document_id,
|
||||
).update(
|
||||
{
|
||||
"error": str(e),
|
||||
"status": "error",
|
||||
"disabled_at": naive_utc_now(),
|
||||
"enabled": False,
|
||||
}
|
||||
)
|
||||
db.session.commit()
|
||||
finally:
|
||||
for segment in segments:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
redis_client.delete(indexing_cache_key)
|
||||
db.session.close()
|
||||
86
dify/api/tasks/mail_account_deletion_task.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
|
||||
from extensions.ext_mail import mail
|
||||
from libs.email_i18n import EmailType, get_email_i18n_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="mail")
|
||||
def send_deletion_success_task(to: str, language: str = "en-US"):
|
||||
"""
|
||||
Send account deletion success email with internationalization support.
|
||||
|
||||
Args:
|
||||
to: Recipient email address
|
||||
language: Language code for email localization
|
||||
"""
|
||||
if not mail.is_inited():
|
||||
return
|
||||
|
||||
logger.info(click.style(f"Start send account deletion success email to {to}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
email_service = get_email_i18n_service()
|
||||
email_service.send_email(
|
||||
email_type=EmailType.ACCOUNT_DELETION_SUCCESS,
|
||||
language_code=language,
|
||||
to=to,
|
||||
template_context={
|
||||
"to": to,
|
||||
"email": to,
|
||||
},
|
||||
)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(f"Send account deletion success email to {to}: latency: {end_at - start_at}", fg="green")
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Send account deletion success email to %s failed", to)
|
||||
|
||||
|
||||
@shared_task(queue="mail")
|
||||
def send_account_deletion_verification_code(to: str, code: str, language: str = "en-US"):
|
||||
"""
|
||||
Send account deletion verification code email with internationalization support.
|
||||
|
||||
Args:
|
||||
to: Recipient email address
|
||||
code: Verification code
|
||||
language: Language code for email localization
|
||||
"""
|
||||
if not mail.is_inited():
|
||||
return
|
||||
|
||||
logger.info(click.style(f"Start send account deletion verification code email to {to}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
email_service = get_email_i18n_service()
|
||||
email_service.send_email(
|
||||
email_type=EmailType.ACCOUNT_DELETION_VERIFICATION,
|
||||
language_code=language,
|
||||
to=to,
|
||||
template_context={
|
||||
"to": to,
|
||||
"code": code,
|
||||
},
|
||||
)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
"Send account deletion verification code email to {} succeeded: latency: {}".format(
|
||||
to, end_at - start_at
|
||||
),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Send account deletion verification code email to %s failed", to)
|
||||
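A hypothetical call-site sketch for the two mail tasks above; the address and code are placeholders. Both tasks default to "en-US" when no language is passed, matching the defaults defined in their signatures.

from tasks.mail_account_deletion_task import (
    send_account_deletion_verification_code,
    send_deletion_success_task,
)

send_account_deletion_verification_code.delay("user@example.com", "123456", language="en-US")
send_deletion_success_task.delay("user@example.com")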
80
dify/api/tasks/mail_change_mail_task.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import logging
import time

import click
from celery import shared_task

from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service

logger = logging.getLogger(__name__)


@shared_task(queue="mail")
def send_change_mail_task(language: str, to: str, code: str, phase: str):
    """
    Send change email notification with internationalization support.

    Args:
        language: Language code for email localization
        to: Recipient email address
        code: Email verification code
        phase: Change email phase ('old_email' or 'new_email')
    """
    if not mail.is_inited():
        return

    logger.info(click.style(f"Start change email mail to {to}", fg="green"))
    start_at = time.perf_counter()

    try:
        email_service = get_email_i18n_service()
        email_service.send_change_email(
            language_code=language,
            to=to,
            code=code,
            phase=phase,
        )

        end_at = time.perf_counter()
        logger.info(click.style(f"Send change email mail to {to} succeeded: latency: {end_at - start_at}", fg="green"))
    except Exception:
        logger.exception("Send change email mail to %s failed", to)


@shared_task(queue="mail")
def send_change_mail_completed_notification_task(language: str, to: str):
    """
    Send change email completed notification with internationalization support.

    Args:
        language: Language code for email localization
        to: Recipient email address
    """
    if not mail.is_inited():
        return

    logger.info(click.style(f"Start change email completed notify mail to {to}", fg="green"))
    start_at = time.perf_counter()

    try:
        email_service = get_email_i18n_service()
        email_service.send_email(
            email_type=EmailType.CHANGE_EMAIL_COMPLETED,
            language_code=language,
            to=to,
            template_context={
                "to": to,
                "email": to,
            },
        )

        end_at = time.perf_counter()
        logger.info(
            click.style(
                f"Send change email completed mail to {to} succeeded: latency: {end_at - start_at}",
                fg="green",
            )
        )
    except Exception:
        logger.exception("Send change email completed mail to %s failed", to)
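

# Editor's note: an illustrative sketch, not part of the original commit. The change-email flow
# documented above delivers one code per phase ('old_email', then 'new_email') and finishes with
# the completion notice; every value here is a placeholder.
def _example_change_mail_flow(old_address: str, new_address: str) -> None:
    send_change_mail_task.delay(language="en-US", to=old_address, code="123456", phase="old_email")
    send_change_mail_task.delay(language="en-US", to=new_address, code="654321", phase="new_email")
    send_change_mail_completed_notification_task.delay(language="en-US", to=new_address)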
46
dify/api/tasks/mail_email_code_login.py
Normal file
@@ -0,0 +1,46 @@
import logging
import time

import click
from celery import shared_task

from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service

logger = logging.getLogger(__name__)


@shared_task(queue="mail")
def send_email_code_login_mail_task(language: str, to: str, code: str):
    """
    Send email code login email with internationalization support.

    Args:
        language: Language code for email localization
        to: Recipient email address
        code: Email verification code
    """
    if not mail.is_inited():
        return

    logger.info(click.style(f"Start email code login mail to {to}", fg="green"))
    start_at = time.perf_counter()

    try:
        email_service = get_email_i18n_service()
        email_service.send_email(
            email_type=EmailType.EMAIL_CODE_LOGIN,
            language_code=language,
            to=to,
            template_context={
                "to": to,
                "code": code,
            },
        )

        end_at = time.perf_counter()
        logger.info(
            click.style(f"Send email code login mail to {to} succeeded: latency: {end_at - start_at}", fg="green")
        )
    except Exception:
        logger.exception("Send email code login mail to %s failed", to)
61
dify/api/tasks/mail_inner_task.py
Normal file
@@ -0,0 +1,61 @@
import logging
import time
from collections.abc import Mapping
from typing import Any

import click
from celery import shared_task
from flask import render_template_string
from jinja2.runtime import Context
from jinja2.sandbox import ImmutableSandboxedEnvironment

from configs import dify_config
from configs.feature import TemplateMode
from extensions.ext_mail import mail
from libs.email_i18n import get_email_i18n_service

logger = logging.getLogger(__name__)


class SandboxedEnvironment(ImmutableSandboxedEnvironment):
    def __init__(self, timeout: int, *args: Any, **kwargs: Any):
        self._timeout_time = time.time() + timeout
        super().__init__(*args, **kwargs)

    def call(self, context: Context, obj: Any, *args: Any, **kwargs: Any) -> Any:
        if time.time() > self._timeout_time:
            raise TimeoutError("Template rendering timeout")
        return super().call(context, obj, *args, **kwargs)


def _render_template_with_strategy(body: str, substitutions: Mapping[str, str]) -> str:
    mode = dify_config.MAIL_TEMPLATING_MODE
    timeout = dify_config.MAIL_TEMPLATING_TIMEOUT
    if mode == TemplateMode.UNSAFE:
        return render_template_string(body, **substitutions)
    if mode == TemplateMode.SANDBOX:
        tmpl = SandboxedEnvironment(timeout=timeout).from_string(body)
        return tmpl.render(substitutions)
    if mode == TemplateMode.DISABLED:
        return body
    raise ValueError(f"Unsupported mail templating mode: {mode}")
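

# Editor's note: an illustrative sketch, not part of the original commit. It shows the SANDBOX
# branch above in isolation: ordinary templates render normally, while the sandbox blocks unsafe
# attribute access and the timeout guard aborts long-running renders. The 5-second budget and
# the template text are arbitrary example values.
def _example_sandbox_rendering() -> str:
    tmpl = SandboxedEnvironment(timeout=5).from_string("Hello {{ name }}!")
    return tmpl.render({"name": "Dify"})  # -> "Hello Dify!"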


@shared_task(queue="mail")
def send_inner_email_task(to: list[str], subject: str, body: str, substitutions: Mapping[str, str]):
    if not mail.is_inited():
        return

    logger.info(click.style(f"Start enterprise mail to {to} with subject {subject}", fg="green"))
    start_at = time.perf_counter()

    try:
        html_content = _render_template_with_strategy(body, substitutions)

        email_service = get_email_i18n_service()
        email_service.send_raw_email(to=to, subject=subject, html_content=html_content)

        end_at = time.perf_counter()
        logger.info(click.style(f"Send enterprise mail to {to} succeeded: latency: {end_at - start_at}", fg="green"))
    except Exception:
        logger.exception("Send enterprise mail to %s failed", to)
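

# Editor's note: an illustrative sketch, not part of the original commit. The substitutions
# mapping feeds the templating strategy above; recipients, subject, and body are placeholders.
def _example_send_inner_email() -> None:
    send_inner_email_task.delay(
        to=["ops@example.com"],
        subject="Weekly report",
        body="Hello {{ name }}, your report is ready.",
        substitutions={"name": "Dify"},
    )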
50
dify/api/tasks/mail_invite_member_task.py
Normal file
@@ -0,0 +1,50 @@
import logging
import time

import click
from celery import shared_task

from configs import dify_config
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service

logger = logging.getLogger(__name__)


@shared_task(queue="mail")
def send_invite_member_mail_task(language: str, to: str, token: str, inviter_name: str, workspace_name: str):
    """
    Send invite member email with internationalization support.

    Args:
        language: Language code for email localization
        to: Recipient email address
        token: Invitation token
        inviter_name: Name of the person sending the invitation
        workspace_name: Name of the workspace
    """
    if not mail.is_inited():
        return

    logger.info(click.style(f"Start send invite member mail to {to} in workspace {workspace_name}", fg="green"))
    start_at = time.perf_counter()

    try:
        url = f"{dify_config.CONSOLE_WEB_URL}/activate?token={token}"
        email_service = get_email_i18n_service()
        email_service.send_email(
            email_type=EmailType.INVITE_MEMBER,
            language_code=language,
            to=to,
            template_context={
                "to": to,
                "inviter_name": inviter_name,
                "workspace_name": workspace_name,
                "url": url,
            },
        )

        end_at = time.perf_counter()
        logger.info(click.style(f"Send invite member mail to {to} succeeded: latency: {end_at - start_at}", fg="green"))
    except Exception:
        logger.exception("Send invite member mail to %s failed", to)
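

# Editor's note: an illustrative sketch, not part of the original commit. The activation URL is
# built inside the task from CONSOLE_WEB_URL, so callers only pass the raw token; every value
# below is a placeholder.
def _example_invite_member(token: str) -> None:
    send_invite_member_mail_task.delay(
        language="en-US",
        to="new.member@example.com",
        token=token,
        inviter_name="Alice",
        workspace_name="Acme Workspace",
    )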
131
dify/api/tasks/mail_owner_transfer_task.py
Normal file
@@ -0,0 +1,131 @@
import logging
import time

import click
from celery import shared_task

from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service

logger = logging.getLogger(__name__)


@shared_task(queue="mail")
def send_owner_transfer_confirm_task(language: str, to: str, code: str, workspace: str):
    """
    Send owner transfer confirmation email with internationalization support.

    Args:
        language: Language code for email localization
        to: Recipient email address
        code: Verification code
        workspace: Workspace name
    """
    if not mail.is_inited():
        return

    logger.info(click.style(f"Start owner transfer confirm mail to {to}", fg="green"))
    start_at = time.perf_counter()

    try:
        email_service = get_email_i18n_service()
        email_service.send_email(
            email_type=EmailType.OWNER_TRANSFER_CONFIRM,
            language_code=language,
            to=to,
            template_context={
                "to": to,
                "code": code,
                "WorkspaceName": workspace,
            },
        )

        end_at = time.perf_counter()
        logger.info(
            click.style(
                f"Send owner transfer confirm mail to {to} succeeded: latency: {end_at - start_at}",
                fg="green",
            )
        )
    except Exception:
        logger.exception("Send owner transfer confirm mail to %s failed", to)


@shared_task(queue="mail")
def send_old_owner_transfer_notify_email_task(language: str, to: str, workspace: str, new_owner_email: str):
    """
    Send old owner transfer notification email with internationalization support.

    Args:
        language: Language code for email localization
        to: Recipient email address
        workspace: Workspace name
        new_owner_email: New owner email address
    """
    if not mail.is_inited():
        return

    logger.info(click.style(f"Start old owner transfer notify mail to {to}", fg="green"))
    start_at = time.perf_counter()

    try:
        email_service = get_email_i18n_service()
        email_service.send_email(
            email_type=EmailType.OWNER_TRANSFER_OLD_NOTIFY,
            language_code=language,
            to=to,
            template_context={
                "to": to,
                "WorkspaceName": workspace,
                "NewOwnerEmail": new_owner_email,
            },
        )

        end_at = time.perf_counter()
        logger.info(
            click.style(
                f"Send old owner transfer notify mail to {to} succeeded: latency: {end_at - start_at}",
                fg="green",
            )
        )
    except Exception:
        logger.exception("Send old owner transfer notify mail to %s failed", to)


@shared_task(queue="mail")
def send_new_owner_transfer_notify_email_task(language: str, to: str, workspace: str):
    """
    Send new owner transfer notification email with internationalization support.

    Args:
        language: Language code for email localization
        to: Recipient email address
        workspace: Workspace name
    """
    if not mail.is_inited():
        return

    logger.info(click.style(f"Start new owner transfer notify mail to {to}", fg="green"))
    start_at = time.perf_counter()

    try:
        email_service = get_email_i18n_service()
        email_service.send_email(
            email_type=EmailType.OWNER_TRANSFER_NEW_NOTIFY,
            language_code=language,
            to=to,
            template_context={
                "to": to,
                "WorkspaceName": workspace,
            },
        )

        end_at = time.perf_counter()
        logger.info(
            click.style(
                f"Send new owner transfer notify mail to {to} succeeded: latency: {end_at - start_at}",
                fg="green",
            )
        )
    except Exception:
        logger.exception("Send new owner transfer notify mail to %s failed", to)
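

# Editor's note: an illustrative sketch, not part of the original commit. A typical transfer
# first asks the current owner to confirm with a code, then notifies both parties once the
# transfer is applied; addresses, workspace name, and the code are placeholders.
def _example_owner_transfer_emails(old_owner: str, new_owner: str) -> None:
    send_owner_transfer_confirm_task.delay(language="en-US", to=old_owner, code="123456", workspace="Acme Workspace")
    send_old_owner_transfer_notify_email_task.delay(
        language="en-US", to=old_owner, workspace="Acme Workspace", new_owner_email=new_owner
    )
    send_new_owner_transfer_notify_email_task.delay(language="en-US", to=new_owner, workspace="Acme Workspace")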
87
dify/api/tasks/mail_register_task.py
Normal file
@@ -0,0 +1,87 @@
import logging
import time

import click
from celery import shared_task

from configs import dify_config
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service

logger = logging.getLogger(__name__)


@shared_task(queue="mail")
def send_email_register_mail_task(language: str, to: str, code: str) -> None:
    """
    Send email registration email with internationalization support.

    Args:
        language: Language code for email localization
        to: Recipient email address
        code: Email register code
    """
    if not mail.is_inited():
        return

    logger.info(click.style(f"Start email register mail to {to}", fg="green"))
    start_at = time.perf_counter()

    try:
        email_service = get_email_i18n_service()
        email_service.send_email(
            email_type=EmailType.EMAIL_REGISTER,
            language_code=language,
            to=to,
            template_context={
                "to": to,
                "code": code,
            },
        )

        end_at = time.perf_counter()
        logger.info(
            click.style(f"Send email register mail to {to} succeeded: latency: {end_at - start_at}", fg="green")
        )
    except Exception:
        logger.exception("Send email register mail to %s failed", to)


@shared_task(queue="mail")
def send_email_register_mail_task_when_account_exist(language: str, to: str, account_name: str) -> None:
    """
    Send email registration email with internationalization support when the account already exists.

    Args:
        language: Language code for email localization
        to: Recipient email address
        account_name: Name of the existing account
    """
    if not mail.is_inited():
        return

    logger.info(click.style(f"Start email register mail to {to}", fg="green"))
    start_at = time.perf_counter()

    try:
        login_url = f"{dify_config.CONSOLE_WEB_URL}/signin"
        reset_password_url = f"{dify_config.CONSOLE_WEB_URL}/reset-password"

        email_service = get_email_i18n_service()
        email_service.send_email(
            email_type=EmailType.EMAIL_REGISTER_WHEN_ACCOUNT_EXIST,
            language_code=language,
            to=to,
            template_context={
                "to": to,
                "login_url": login_url,
                "reset_password_url": reset_password_url,
                "account_name": account_name,
            },
        )

        end_at = time.perf_counter()
        logger.info(
            click.style(f"Send email register mail to {to} succeeded: latency: {end_at - start_at}", fg="green")
        )
    except Exception:
        logger.exception("Send email register mail to %s failed", to)
91
dify/api/tasks/mail_reset_password_task.py
Normal file
@@ -0,0 +1,91 @@
import logging
import time

import click
from celery import shared_task

from configs import dify_config
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service

logger = logging.getLogger(__name__)


@shared_task(queue="mail")
def send_reset_password_mail_task(language: str, to: str, code: str):
    """
    Send reset password email with internationalization support.

    Args:
        language: Language code for email localization
        to: Recipient email address
        code: Reset password code
    """
    if not mail.is_inited():
        return

    logger.info(click.style(f"Start password reset mail to {to}", fg="green"))
    start_at = time.perf_counter()

    try:
        email_service = get_email_i18n_service()
        email_service.send_email(
            email_type=EmailType.RESET_PASSWORD,
            language_code=language,
            to=to,
            template_context={
                "to": to,
                "code": code,
            },
        )

        end_at = time.perf_counter()
        logger.info(
            click.style(f"Send password reset mail to {to} succeeded: latency: {end_at - start_at}", fg="green")
        )
    except Exception:
        logger.exception("Send password reset mail to %s failed", to)


@shared_task(queue="mail")
def send_reset_password_mail_task_when_account_not_exist(language: str, to: str, is_allow_register: bool) -> None:
    """
    Send reset password email with internationalization support when the account does not exist.

    Args:
        language: Language code for email localization
        to: Recipient email address
        is_allow_register: Whether self-registration is allowed, which selects the email template
    """
    if not mail.is_inited():
        return

    logger.info(click.style(f"Start password reset mail to {to}", fg="green"))
    start_at = time.perf_counter()

    try:
        if is_allow_register:
            sign_up_url = f"{dify_config.CONSOLE_WEB_URL}/signup"
            email_service = get_email_i18n_service()
            email_service.send_email(
                email_type=EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST,
                language_code=language,
                to=to,
                template_context={
                    "to": to,
                    "sign_up_url": sign_up_url,
                },
            )
        else:
            email_service = get_email_i18n_service()
            email_service.send_email(
                email_type=EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER,
                language_code=language,
                to=to,
            )

        end_at = time.perf_counter()
        logger.info(
            click.style(f"Send password reset mail to {to} succeeded: latency: {end_at - start_at}", fg="green")
        )
    except Exception:
        logger.exception("Send password reset mail to %s failed", to)
55
dify/api/tasks/ops_trace_task.py
Normal file
@@ -0,0 +1,55 @@
import json
import logging

from celery import shared_task
from flask import current_app

from core.ops.entities.config_entity import OPS_FILE_PATH, OPS_TRACE_FAILED_KEY
from core.ops.entities.trace_entity import trace_info_info_map
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from models.model import Message
from models.workflow import WorkflowRun

logger = logging.getLogger(__name__)


@shared_task(queue="ops_trace")
def process_trace_tasks(file_info):
    """
    Async process trace tasks
    Usage: process_trace_tasks.delay(tasks_data)
    """
    from core.ops.ops_trace_manager import OpsTraceManager

    app_id = file_info.get("app_id")
    file_id = file_info.get("file_id")
    file_path = f"{OPS_FILE_PATH}{app_id}/{file_id}.json"
    file_data = json.loads(storage.load(file_path))
    trace_info = file_data.get("trace_info")
    trace_info_type = file_data.get("trace_info_type")
    trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)

    if trace_info.get("message_data"):
        trace_info["message_data"] = Message.from_dict(data=trace_info["message_data"])
    if trace_info.get("workflow_data"):
        trace_info["workflow_data"] = WorkflowRun.from_dict(data=trace_info["workflow_data"])
    if trace_info.get("documents"):
        trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]]

    try:
        if trace_instance:
            with current_app.app_context():
                trace_type = trace_info_info_map.get(trace_info_type)
                if trace_type:
                    trace_info = trace_type(**trace_info)
                trace_instance.trace(trace_info)
        logger.info("Processing trace tasks success, app_id: %s", app_id)
    except Exception:
        logger.exception("Processing trace tasks failed, app_id: %s", app_id)
        failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
        redis_client.incr(failed_key)
    finally:
        storage.delete(file_path)
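

# Editor's note: an illustrative sketch, not part of the original commit. The task above expects
# a mapping with the app id and the id of a JSON file previously written under OPS_FILE_PATH;
# both ids below are placeholders.
def _example_enqueue_trace_task() -> None:
    process_trace_tasks.delay({"app_id": "<app-uuid>", "file_id": "<file-uuid>"})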
234
dify/api/tasks/process_tenant_plugin_autoupgrade_check_task.py
Normal file
@@ -0,0 +1,234 @@
import json
import operator
import typing

import click
from celery import shared_task

from core.helper import marketplace
from core.helper.marketplace import MarketplacePluginDeclaration
from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_redis import redis_client
from models.account import TenantPluginAutoUpgradeStrategy

RETRY_TIMES_OF_ONE_PLUGIN_IN_ONE_TENANT = 3
CACHE_REDIS_KEY_PREFIX = "plugin_autoupgrade_check_task:cached_plugin_manifests:"
CACHE_REDIS_TTL = 60 * 15  # 15 minutes


def _get_redis_cache_key(plugin_id: str) -> str:
    """Generate Redis cache key for plugin manifest."""
    return f"{CACHE_REDIS_KEY_PREFIX}{plugin_id}"


def _get_cached_manifest(plugin_id: str) -> typing.Union[MarketplacePluginDeclaration, None, bool]:
    """
    Get cached plugin manifest from Redis.
    Returns:
        - MarketplacePluginDeclaration: if found in cache
        - None: if cached as not found (marketplace returned no result)
        - False: if not in cache at all
    """
    try:
        key = _get_redis_cache_key(plugin_id)
        cached_data = redis_client.get(key)
        if cached_data is None:
            return False

        cached_json = json.loads(cached_data)
        if cached_json is None:
            return None

        return MarketplacePluginDeclaration.model_validate(cached_json)
    except Exception:
        return False


def _set_cached_manifest(plugin_id: str, manifest: typing.Union[MarketplacePluginDeclaration, None]) -> None:
    """
    Cache plugin manifest in Redis.
    Args:
        plugin_id: The plugin ID
        manifest: The manifest to cache, or None if not found in marketplace
    """
    try:
        key = _get_redis_cache_key(plugin_id)
        if manifest is None:
            # Cache the fact that this plugin was not found
            redis_client.setex(key, CACHE_REDIS_TTL, json.dumps(None))
        else:
            # Cache the manifest data
            redis_client.setex(key, CACHE_REDIS_TTL, manifest.model_dump_json())
    except Exception:
        # If Redis fails, continue without caching
        # traceback.print_exc()
        pass


def marketplace_batch_fetch_plugin_manifests(
    plugin_ids_plain_list: list[str],
) -> list[MarketplacePluginDeclaration]:
    """Fetch plugin manifests with Redis caching support."""
    cached_manifests: dict[str, typing.Union[MarketplacePluginDeclaration, None]] = {}
    not_cached_plugin_ids: list[str] = []

    # Check Redis cache for each plugin
    for plugin_id in plugin_ids_plain_list:
        cached_result = _get_cached_manifest(plugin_id)
        if cached_result is False:
            # Not in cache, need to fetch
            not_cached_plugin_ids.append(plugin_id)
        else:
            # Either found manifest or cached as None (not found in marketplace)
            # At this point, cached_result is either MarketplacePluginDeclaration or None
            if isinstance(cached_result, bool):
                # This should never happen due to the if condition above, but for type safety
                continue
            cached_manifests[plugin_id] = cached_result

    # Fetch uncached plugins from marketplace
    if not_cached_plugin_ids:
        manifests = marketplace.batch_fetch_plugin_manifests_ignore_deserialization_error(not_cached_plugin_ids)

        # Cache the fetched manifests
        for manifest in manifests:
            cached_manifests[manifest.plugin_id] = manifest
            _set_cached_manifest(manifest.plugin_id, manifest)

        # Cache plugins that were not found in marketplace
        fetched_plugin_ids = {manifest.plugin_id for manifest in manifests}
        for plugin_id in not_cached_plugin_ids:
            if plugin_id not in fetched_plugin_ids:
                cached_manifests[plugin_id] = None
                _set_cached_manifest(plugin_id, None)

    # Build result list from cached manifests
    result: list[MarketplacePluginDeclaration] = []
    for plugin_id in plugin_ids_plain_list:
        cached_manifest: typing.Union[MarketplacePluginDeclaration, None] = cached_manifests.get(plugin_id)
        if cached_manifest is not None:
            result.append(cached_manifest)

    return result


@shared_task(queue="plugin")
def process_tenant_plugin_autoupgrade_check_task(
    tenant_id: str,
    strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting,
    upgrade_time_of_day: int,
    upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode,
    exclude_plugins: list[str],
    include_plugins: list[str],
):
    try:
        manager = PluginInstaller()

        click.echo(
            click.style(
                f"Checking upgradable plugin for tenant: {tenant_id}",
                fg="green",
            )
        )

        if strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED:
            return

        # get plugin_ids to check
        plugin_ids: list[tuple[str, str, str]] = []  # plugin_id, version, unique_identifier
        click.echo(click.style(f"Upgrade mode: {upgrade_mode}", fg="green"))

        if upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL and include_plugins:
            all_plugins = manager.list_plugins(tenant_id)

            for plugin in all_plugins:
                if plugin.source == PluginInstallationSource.Marketplace and plugin.plugin_id in include_plugins:
                    plugin_ids.append(
                        (
                            plugin.plugin_id,
                            plugin.version,
                            plugin.plugin_unique_identifier,
                        )
                    )

        elif upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE:
            # get all plugins and remove excluded plugins
            all_plugins = manager.list_plugins(tenant_id)
            plugin_ids = [
                (plugin.plugin_id, plugin.version, plugin.plugin_unique_identifier)
                for plugin in all_plugins
                if plugin.source == PluginInstallationSource.Marketplace and plugin.plugin_id not in exclude_plugins
            ]
        elif upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL:
            all_plugins = manager.list_plugins(tenant_id)
            plugin_ids = [
                (plugin.plugin_id, plugin.version, plugin.plugin_unique_identifier)
                for plugin in all_plugins
                if plugin.source == PluginInstallationSource.Marketplace
            ]

        if not plugin_ids:
            return

        plugin_ids_plain_list = [plugin_id for plugin_id, _, _ in plugin_ids]

        manifests = marketplace_batch_fetch_plugin_manifests(plugin_ids_plain_list)

        if not manifests:
            return

        for manifest in manifests:
            for plugin_id, version, original_unique_identifier in plugin_ids:
                if manifest.plugin_id != plugin_id:
                    continue

                try:
                    current_version = version
                    latest_version = manifest.latest_version

                    def fix_only_checker(latest_version: str, current_version: str):
                        latest_version_tuple = tuple(int(val) for val in latest_version.split("."))
                        current_version_tuple = tuple(int(val) for val in current_version.split("."))

                        if (
                            latest_version_tuple[0] == current_version_tuple[0]
                            and latest_version_tuple[1] == current_version_tuple[1]
                        ):
                            return latest_version_tuple[2] != current_version_tuple[2]
                        return False
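
                    # Editor's note (not part of the original commit): with FIX_ONLY,
                    # 1.2.3 -> 1.2.5 qualifies for an upgrade (same major.minor, new patch),
                    # while 1.2.3 -> 1.3.0 does not; LATEST upgrades on any version difference.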

                    version_checker = {
                        TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: operator.ne,
                        TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY: fix_only_checker,
                    }

                    if version_checker[strategy_setting](latest_version, current_version):
                        # execute upgrade
                        new_unique_identifier = manifest.latest_package_identifier

                        marketplace.record_install_plugin_event(new_unique_identifier)
                        click.echo(
                            click.style(
                                f"Upgrade plugin: {original_unique_identifier} -> {new_unique_identifier}",
                                fg="green",
                            )
                        )
                        _ = manager.upgrade_plugin(
                            tenant_id,
                            original_unique_identifier,
                            new_unique_identifier,
                            PluginInstallationSource.Marketplace,
                            {
                                "plugin_unique_identifier": new_unique_identifier,
                            },
                        )
                except Exception as e:
                    click.echo(click.style(f"Error when upgrading plugin: {e}", fg="red"))
                    # traceback.print_exc()
                    break

    except Exception as e:
        click.echo(click.style(f"Error when checking upgradable plugin: {e}", fg="red"))
        # traceback.print_exc()
        return
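

# Editor's note: an illustrative sketch, not part of the original commit. It spells out how a
# caller would interpret the tri-state return of _get_cached_manifest above; the plugin id is
# a placeholder.
def _example_cache_lookup(plugin_id: str) -> str:
    cached = _get_cached_manifest(plugin_id)
    if cached is False:
        return "not cached yet -> fetch from the marketplace"
    if cached is None:
        return "cached negative result -> skip for up to 15 minutes"
    return f"cached manifest, latest version {cached.latest_version}"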
187
dify/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py
Normal file
@@ -0,0 +1,187 @@
import contextvars
import json
import logging
import time
import uuid
from collections.abc import Mapping
from concurrent.futures import ThreadPoolExecutor
from typing import Any

import click
from celery import shared_task  # type: ignore
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker

from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from core.repositories.factory import DifyCoreRepositoryFactory
from extensions.ext_database import db
from models import Account, Tenant
from models.dataset import Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService

logger = logging.getLogger(__name__)


@shared_task(queue="priority_pipeline")
def priority_rag_pipeline_run_task(
    rag_pipeline_invoke_entities_file_id: str,
    tenant_id: str,
):
    """
    Async Run rag pipeline task using high priority queue.

    :param rag_pipeline_invoke_entities_file_id: File ID containing serialized RAG pipeline invoke entities
    :param tenant_id: Tenant ID for the pipeline execution
    """
    # run with threading, thread pool size is 10

    try:
        start_at = time.perf_counter()
        rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(
            rag_pipeline_invoke_entities_file_id
        )
        rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)

        # Get Flask app object for thread context
        flask_app = current_app._get_current_object()  # type: ignore

        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = []
            for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities:
                # Submit task to thread pool with Flask app
                future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity, flask_app)
                futures.append(future)

            # Wait for all tasks to complete
            for future in futures:
                try:
                    future.result()  # This will raise any exceptions that occurred in the thread
                except Exception:
                    logging.exception("Error in pipeline task")
        end_at = time.perf_counter()
        logging.info(
            click.style(
                f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
            )
        )
    except Exception:
        logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
        raise
    finally:
        tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "pipeline")

        # Check if there are waiting tasks in the queue
        # Use rpop to get the next task from the queue (FIFO order)
        next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
        logger.info("priority rag pipeline tenant isolation queue next files: %s", next_file_ids)

        if next_file_ids:
            for next_file_id in next_file_ids:
                # Process the next waiting task
                # Keep the flag set to indicate a task is running
                tenant_isolated_task_queue.set_task_waiting_time()
                priority_rag_pipeline_run_task.delay(  # type: ignore
                    rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
                    if isinstance(next_file_id, bytes)
                    else next_file_id,
                    tenant_id=tenant_id,
                )
        else:
            # No more waiting tasks, clear the flag
            tenant_isolated_task_queue.delete_task_key()
        file_service = FileService(db.engine)
        file_service.delete_file(rag_pipeline_invoke_entities_file_id)
        db.session.close()


def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app):
    """Run a single RAG pipeline task within Flask app context."""
    # Create Flask application context for this thread
    with flask_app.app_context():
        try:
            rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity.model_validate(rag_pipeline_invoke_entity)
            user_id = rag_pipeline_invoke_entity_model.user_id
            tenant_id = rag_pipeline_invoke_entity_model.tenant_id
            pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
            workflow_id = rag_pipeline_invoke_entity_model.workflow_id
            streaming = rag_pipeline_invoke_entity_model.streaming
            workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
            workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
            application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity

            with Session(db.engine, expire_on_commit=False) as session:
                # Load required entities
                account = session.query(Account).where(Account.id == user_id).first()
                if not account:
                    raise ValueError(f"Account {user_id} not found")

                tenant = session.query(Tenant).where(Tenant.id == tenant_id).first()
                if not tenant:
                    raise ValueError(f"Tenant {tenant_id} not found")
                account.current_tenant = tenant

                pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
                if not pipeline:
                    raise ValueError(f"Pipeline {pipeline_id} not found")

                workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
                if not workflow:
                    raise ValueError(f"Workflow {pipeline.workflow_id} not found")

                if workflow_execution_id is None:
                    workflow_execution_id = str(uuid.uuid4())

                # Create application generate entity from dict
                entity = RagPipelineGenerateEntity.model_validate(application_generate_entity)

                # Create workflow repositories
                session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
                workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
                    session_factory=session_factory,
                    user=account,
                    app_id=entity.app_config.app_id,
                    triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
                )

                workflow_node_execution_repository = (
                    DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
                        session_factory=session_factory,
                        user=account,
                        app_id=entity.app_config.app_id,
                        triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
                    )
                )

                # Set the user directly in g for preserve_flask_contexts
                g._login_user = account

                # Copy context for passing to pipeline generator
                context = contextvars.copy_context()

                # Direct execution without creating another thread
                # Since we're already in a thread pool, no need for nested threading
                from core.app.apps.pipeline.pipeline_generator import PipelineGenerator

                pipeline_generator = PipelineGenerator()
                # Using protected method intentionally for async execution
                pipeline_generator._generate(  # type: ignore[attr-defined]
                    flask_app=flask_app,
                    context=context,
                    pipeline=pipeline,
                    workflow_id=workflow_id,
                    user=account,
                    application_generate_entity=entity,
                    invoke_from=InvokeFrom.PUBLISHED,
                    workflow_execution_repository=workflow_execution_repository,
                    workflow_node_execution_repository=workflow_node_execution_repository,
                    streaming=streaming,
                    workflow_thread_pool_id=workflow_thread_pool_id,
                )
        except Exception:
            logging.exception("Error in priority pipeline task")
            raise
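

# Editor's note: an illustrative sketch, not part of the original commit. The task above reads its
# serialized invoke entities back through FileService, so callers enqueue it with the id of a file
# they persisted beforehand; both arguments below are placeholders.
def _example_enqueue_priority_pipeline_run(file_id: str, tenant_id: str) -> None:
    priority_rag_pipeline_run_task.delay(
        rag_pipeline_invoke_entities_file_id=file_id,
        tenant_id=tenant_id,
    )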
187
dify/api/tasks/rag_pipeline/rag_pipeline_run_task.py
Normal file
@@ -0,0 +1,187 @@
import contextvars
import json
import logging
import time
import uuid
from collections.abc import Mapping
from concurrent.futures import ThreadPoolExecutor
from typing import Any

import click
from celery import shared_task  # type: ignore
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker

from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from core.repositories.factory import DifyCoreRepositoryFactory
from extensions.ext_database import db
from models import Account, Tenant
from models.dataset import Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService

logger = logging.getLogger(__name__)


@shared_task(queue="pipeline")
def rag_pipeline_run_task(
    rag_pipeline_invoke_entities_file_id: str,
    tenant_id: str,
):
    """
    Async Run rag pipeline task using regular priority queue.

    :param rag_pipeline_invoke_entities_file_id: File ID containing serialized RAG pipeline invoke entities
    :param tenant_id: Tenant ID for the pipeline execution
    """
    # run with threading, thread pool size is 10

    try:
        start_at = time.perf_counter()
        rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(
            rag_pipeline_invoke_entities_file_id
        )
        rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)

        # Get Flask app object for thread context
        flask_app = current_app._get_current_object()  # type: ignore

        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = []
            for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities:
                # Submit task to thread pool with Flask app
                future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity, flask_app)
                futures.append(future)

            # Wait for all tasks to complete
            for future in futures:
                try:
                    future.result()  # This will raise any exceptions that occurred in the thread
                except Exception:
                    logging.exception("Error in pipeline task")
        end_at = time.perf_counter()
        logging.info(
            click.style(
                f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
            )
        )
    except Exception:
        logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
        raise
    finally:
        tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "pipeline")

        # Check if there are waiting tasks in the queue
        # Use rpop to get the next task from the queue (FIFO order)
        next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
        logger.info("rag pipeline tenant isolation queue next files: %s", next_file_ids)

        if next_file_ids:
            for next_file_id in next_file_ids:
                # Process the next waiting task
                # Keep the flag set to indicate a task is running
                tenant_isolated_task_queue.set_task_waiting_time()
                rag_pipeline_run_task.delay(  # type: ignore
                    rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
                    if isinstance(next_file_id, bytes)
                    else next_file_id,
                    tenant_id=tenant_id,
                )
        else:
            # No more waiting tasks, clear the flag
            tenant_isolated_task_queue.delete_task_key()
        file_service = FileService(db.engine)
        file_service.delete_file(rag_pipeline_invoke_entities_file_id)
        db.session.close()


def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app):
    """Run a single RAG pipeline task within Flask app context."""
    # Create Flask application context for this thread
    with flask_app.app_context():
        try:
            rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity.model_validate(rag_pipeline_invoke_entity)
            user_id = rag_pipeline_invoke_entity_model.user_id
            tenant_id = rag_pipeline_invoke_entity_model.tenant_id
            pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
            workflow_id = rag_pipeline_invoke_entity_model.workflow_id
            streaming = rag_pipeline_invoke_entity_model.streaming
            workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
            workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
            application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity

            with Session(db.engine) as session:
                # Load required entities
                account = session.query(Account).where(Account.id == user_id).first()
                if not account:
                    raise ValueError(f"Account {user_id} not found")

                tenant = session.query(Tenant).where(Tenant.id == tenant_id).first()
                if not tenant:
                    raise ValueError(f"Tenant {tenant_id} not found")
                account.current_tenant = tenant

                pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
                if not pipeline:
                    raise ValueError(f"Pipeline {pipeline_id} not found")

                workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
                if not workflow:
                    raise ValueError(f"Workflow {pipeline.workflow_id} not found")

                if workflow_execution_id is None:
                    workflow_execution_id = str(uuid.uuid4())

                # Create application generate entity from dict
                entity = RagPipelineGenerateEntity.model_validate(application_generate_entity)

                # Create workflow repositories
                session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
                workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
                    session_factory=session_factory,
                    user=account,
                    app_id=entity.app_config.app_id,
                    triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
                )

                workflow_node_execution_repository = (
                    DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
                        session_factory=session_factory,
                        user=account,
                        app_id=entity.app_config.app_id,
                        triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
                    )
                )

                # Set the user directly in g for preserve_flask_contexts
                g._login_user = account

                # Copy context for passing to pipeline generator
                context = contextvars.copy_context()

                # Direct execution without creating another thread
                # Since we're already in a thread pool, no need for nested threading
                from core.app.apps.pipeline.pipeline_generator import PipelineGenerator

                pipeline_generator = PipelineGenerator()
                # Using protected method intentionally for async execution
                pipeline_generator._generate(  # type: ignore[attr-defined]
                    flask_app=flask_app,
                    context=context,
                    pipeline=pipeline,
                    workflow_id=workflow_id,
                    user=account,
                    application_generate_entity=entity,
                    invoke_from=InvokeFrom.PUBLISHED,
                    workflow_execution_repository=workflow_execution_repository,
                    workflow_node_execution_repository=workflow_node_execution_repository,
                    streaming=streaming,
                    workflow_thread_pool_id=workflow_thread_pool_id,
                )
        except Exception:
            logging.exception("Error in pipeline task")
            raise
48
dify/api/tasks/recover_document_indexing_task.py
Normal file
@@ -0,0 +1,48 @@
import logging
import time

import click
from celery import shared_task

from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from extensions.ext_database import db
from models.dataset import Document

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def recover_document_indexing_task(dataset_id: str, document_id: str):
    """
    Async recover document
    :param dataset_id:
    :param document_id:

    Usage: recover_document_indexing_task.delay(dataset_id, document_id)
    """
    logger.info(click.style(f"Recover document: {document_id}", fg="green"))
    start_at = time.perf_counter()

    document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()

    if not document:
        logger.info(click.style(f"Document not found: {document_id}", fg="red"))
        db.session.close()
        return

    try:
        indexing_runner = IndexingRunner()
        if document.indexing_status in {"waiting", "parsing", "cleaning"}:
            indexing_runner.run([document])
        elif document.indexing_status == "splitting":
            indexing_runner.run_in_splitting_status(document)
        elif document.indexing_status == "indexing":
            indexing_runner.run_in_indexing_status(document)
        end_at = time.perf_counter()
        logger.info(click.style(f"Processed document: {document.id} latency: {end_at - start_at}", fg="green"))
    except DocumentIsPausedError as ex:
        logger.info(click.style(str(ex), fg="yellow"))
    except Exception:
        logger.exception("recover_document_indexing_task failed, document_id: %s", document_id)
    finally:
        db.session.close()
577
dify/api/tasks/remove_app_and_related_data_task.py
Normal file
@@ -0,0 +1,577 @@
import logging
import time
from collections.abc import Callable

import click
import sqlalchemy as sa
from celery import shared_task
from sqlalchemy import delete
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker

from extensions.ext_database import db
from models import (
    ApiToken,
    AppAnnotationHitHistory,
    AppAnnotationSetting,
    AppDatasetJoin,
    AppMCPServer,
    AppModelConfig,
    AppTrigger,
    Conversation,
    EndUser,
    InstalledApp,
    Message,
    MessageAgentThought,
    MessageAnnotation,
    MessageChain,
    MessageFeedback,
    MessageFile,
    RecommendedApp,
    Site,
    TagBinding,
    TraceAppConfig,
    WorkflowSchedulePlan,
)
from models.tools import WorkflowToolProvider
from models.trigger import WorkflowPluginTrigger, WorkflowTriggerLog, WorkflowWebhookTrigger
from models.web import PinnedConversation, SavedMessage
from models.workflow import (
    ConversationVariable,
    Workflow,
    WorkflowAppLog,
)
from repositories.factory import DifyAPIRepositoryFactory

logger = logging.getLogger(__name__)


@shared_task(queue="app_deletion", bind=True, max_retries=3)
def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
    logger.info(click.style(f"Start deleting app and related data: {tenant_id}:{app_id}", fg="green"))
    start_at = time.perf_counter()
    try:
        # Delete related data
        _delete_app_model_configs(tenant_id, app_id)
        _delete_app_site(tenant_id, app_id)
        _delete_app_mcp_servers(tenant_id, app_id)
        _delete_app_api_tokens(tenant_id, app_id)
        _delete_installed_apps(tenant_id, app_id)
        _delete_recommended_apps(tenant_id, app_id)
        _delete_app_annotation_data(tenant_id, app_id)
        _delete_app_dataset_joins(tenant_id, app_id)
        _delete_app_workflows(tenant_id, app_id)
        _delete_app_workflow_runs(tenant_id, app_id)
        _delete_app_workflow_node_executions(tenant_id, app_id)
        _delete_app_workflow_app_logs(tenant_id, app_id)
        _delete_app_conversations(tenant_id, app_id)
        _delete_app_messages(tenant_id, app_id)
        _delete_workflow_tool_providers(tenant_id, app_id)
        _delete_app_tag_bindings(tenant_id, app_id)
        _delete_end_users(tenant_id, app_id)
        _delete_trace_app_configs(tenant_id, app_id)
        _delete_conversation_variables(app_id=app_id)
        _delete_draft_variables(app_id)
        _delete_app_triggers(tenant_id, app_id)
        _delete_workflow_plugin_triggers(tenant_id, app_id)
        _delete_workflow_webhook_triggers(tenant_id, app_id)
        _delete_workflow_schedule_plans(tenant_id, app_id)
        _delete_workflow_trigger_logs(tenant_id, app_id)

        end_at = time.perf_counter()
        logger.info(click.style(f"App and related data deleted: {app_id} latency: {end_at - start_at}", fg="green"))
    except SQLAlchemyError as e:
        logger.exception(click.style(f"Database error occurred while deleting app {app_id} and related data", fg="red"))
        raise self.retry(exc=e, countdown=60)  # Retry after 60 seconds
    except Exception as e:
        logger.exception(click.style(f"Error occurred while deleting app {app_id} and related data", fg="red"))
        raise self.retry(exc=e, countdown=60)  # Retry after 60 seconds


def _delete_app_model_configs(tenant_id: str, app_id: str):
    def del_model_config(model_config_id: str):
        db.session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)

    _delete_records(
        """select id from app_model_configs where app_id=:app_id limit 1000""",
        {"app_id": app_id},
        del_model_config,
        "app model config",
    )
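

# Editor's note: an illustrative sketch, not part of the original commit. _delete_records is
# defined further down in this module; the helper below only illustrates the batching pattern
# the callers above rely on (select up to 1000 ids at a time, delete each via the callback,
# commit, repeat until the query returns no rows). Details of the real helper may differ.
def _example_batched_delete(query_text: str, params: dict, delete_func: Callable[[str], None]) -> None:
    while True:
        with db.engine.begin() as conn:
            rows = conn.execute(sa.text(query_text), params).fetchall()
        if not rows:
            break
        for (record_id,) in rows:
            delete_func(str(record_id))
        db.session.commit()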


def _delete_app_site(tenant_id: str, app_id: str):
    def del_site(site_id: str):
        db.session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)

    _delete_records(
        """select id from sites where app_id=:app_id limit 1000""",
        {"app_id": app_id},
        del_site,
        "site",
    )


def _delete_app_mcp_servers(tenant_id: str, app_id: str):
    def del_mcp_server(mcp_server_id: str):
        db.session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)

    _delete_records(
        """select id from app_mcp_servers where app_id=:app_id limit 1000""",
        {"app_id": app_id},
        del_mcp_server,
        "app mcp server",
    )


def _delete_app_api_tokens(tenant_id: str, app_id: str):
    def del_api_token(api_token_id: str):
        db.session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)

    _delete_records(
        """select id from api_tokens where app_id=:app_id limit 1000""",
        {"app_id": app_id},
        del_api_token,
        "api token",
    )


def _delete_installed_apps(tenant_id: str, app_id: str):
    def del_installed_app(installed_app_id: str):
        db.session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)

    _delete_records(
        """select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
        {"tenant_id": tenant_id, "app_id": app_id},
        del_installed_app,
        "installed app",
    )


def _delete_recommended_apps(tenant_id: str, app_id: str):
    def del_recommended_app(recommended_app_id: str):
        db.session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(
            synchronize_session=False
        )

    _delete_records(
        """select id from recommended_apps where app_id=:app_id limit 1000""",
        {"app_id": app_id},
        del_recommended_app,
        "recommended app",
    )


def _delete_app_annotation_data(tenant_id: str, app_id: str):
    def del_annotation_hit_history(annotation_hit_history_id: str):
        db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
            synchronize_session=False
        )

    _delete_records(
        """select id from app_annotation_hit_histories where app_id=:app_id limit 1000""",
        {"app_id": app_id},
        del_annotation_hit_history,
        "annotation hit history",
    )

    def del_annotation_setting(annotation_setting_id: str):
        db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
            synchronize_session=False
        )

    _delete_records(
        """select id from app_annotation_settings where app_id=:app_id limit 1000""",
        {"app_id": app_id},
        del_annotation_setting,
        "annotation setting",
    )


def _delete_app_dataset_joins(tenant_id: str, app_id: str):
    def del_dataset_join(dataset_join_id: str):
        db.session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)

    _delete_records(
        """select id from app_dataset_joins where app_id=:app_id limit 1000""",
        {"app_id": app_id},
        del_dataset_join,
        "dataset join",
    )


def _delete_app_workflows(tenant_id: str, app_id: str):
    def del_workflow(workflow_id: str):
        db.session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)

    _delete_records(
        """select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
        {"tenant_id": tenant_id, "app_id": app_id},
        del_workflow,
        "workflow",
    )


def _delete_app_workflow_runs(tenant_id: str, app_id: str):
    """Delete all workflow runs for an app using the service repository."""
    session_maker = sessionmaker(bind=db.engine)
    workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)

    deleted_count = workflow_run_repo.delete_runs_by_app(
        tenant_id=tenant_id,
        app_id=app_id,
        batch_size=1000,
    )

    logger.info("Deleted %s workflow runs for app %s", deleted_count, app_id)


def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
    """Delete all workflow node executions for an app using the service repository."""
    session_maker = sessionmaker(bind=db.engine)
    node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)

    deleted_count = node_execution_repo.delete_executions_by_app(
        tenant_id=tenant_id,
        app_id=app_id,
        batch_size=1000,
    )

    logger.info("Deleted %s workflow node executions for app %s", deleted_count, app_id)


def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
    def del_workflow_app_log(workflow_app_log_id: str):
        db.session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(
            synchronize_session=False
        )

    _delete_records(
        """select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
        {"tenant_id": tenant_id, "app_id": app_id},
        del_workflow_app_log,
        "workflow app log",
    )


def _delete_app_conversations(tenant_id: str, app_id: str):
    def del_conversation(conversation_id: str):
        db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
            synchronize_session=False
        )
        db.session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)

    _delete_records(
        """select id from conversations where app_id=:app_id limit 1000""",
        {"app_id": app_id},
        del_conversation,
        "conversation",
    )


def _delete_conversation_variables(*, app_id: str):
    stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
    with db.engine.connect() as conn:
        conn.execute(stmt)
        conn.commit()
        logger.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green"))


def _delete_app_messages(tenant_id: str, app_id: str):
    def del_message(message_id: str):
        db.session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(
            synchronize_session=False
        )
        db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
            synchronize_session=False
        )
        db.session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False)
        db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete(
            synchronize_session=False
        )
        db.session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False)
        db.session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False)
        db.session.query(Message).where(Message.id == message_id).delete()

    _delete_records(
        """select id from messages where app_id=:app_id limit 1000""",
        {"app_id": app_id},
        del_message,
        "message",
    )


def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
    def del_tool_provider(tool_provider_id: str):
        db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
            synchronize_session=False
        )

    _delete_records(
        """select id from tool_workflow_providers where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
        {"tenant_id": tenant_id, "app_id": app_id},
        del_tool_provider,
        "tool workflow provider",
    )


def _delete_app_tag_bindings(tenant_id: str, app_id: str):
    def del_tag_binding(tag_binding_id: str):
        db.session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)

    _delete_records(
        """select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""",
        {"tenant_id": tenant_id, "app_id": app_id},
        del_tag_binding,
        "tag binding",
    )


def _delete_end_users(tenant_id: str, app_id: str):
    def del_end_user(end_user_id: str):
        db.session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)

    _delete_records(
        """select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
        {"tenant_id": tenant_id, "app_id": app_id},
        del_end_user,
        "end user",
    )
|
||||
|
||||
|
||||
def _delete_trace_app_configs(tenant_id: str, app_id: str):
|
||||
def del_trace_app_config(trace_app_config_id: str):
|
||||
db.session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
"""select id from trace_app_config where app_id=:app_id limit 1000""",
|
||||
{"app_id": app_id},
|
||||
del_trace_app_config,
|
||||
"trace app config",
|
||||
)
|
||||
|
||||
|
||||
def _delete_draft_variables(app_id: str):
|
||||
"""Delete all workflow draft variables for an app in batches."""
|
||||
return delete_draft_variables_batch(app_id, batch_size=1000)
|
||||
|
||||
|
||||
def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
|
||||
"""
|
||||
Delete draft variables for an app in batches.
|
||||
|
||||
This function now handles cleanup of associated Offload data including:
|
||||
- WorkflowDraftVariableFile records
|
||||
- UploadFile records
|
||||
- Object storage files
|
||||
|
||||
Args:
|
||||
app_id: The ID of the app whose draft variables should be deleted
|
||||
batch_size: Number of records to delete per batch
|
||||
|
||||
Returns:
|
||||
Total number of records deleted
|
||||
"""
|
||||
if batch_size <= 0:
|
||||
raise ValueError("batch_size must be positive")
|
||||
|
||||
total_deleted = 0
|
||||
total_files_deleted = 0
|
||||
|
||||
while True:
|
||||
with db.engine.begin() as conn:
|
||||
# Get a batch of draft variable IDs along with their file_ids
|
||||
query_sql = """
|
||||
SELECT id, file_id FROM workflow_draft_variables
|
||||
WHERE app_id = :app_id
|
||||
LIMIT :batch_size
|
||||
"""
|
||||
result = conn.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size})
|
||||
|
||||
rows = list(result)
|
||||
if not rows:
|
||||
break
|
||||
|
||||
draft_var_ids = [row[0] for row in rows]
|
||||
file_ids = [row[1] for row in rows if row[1] is not None]
|
||||
|
||||
# Clean up associated Offload data first
|
||||
if file_ids:
|
||||
files_deleted = _delete_draft_variable_offload_data(conn, file_ids)
|
||||
total_files_deleted += files_deleted
|
||||
|
||||
# Delete the draft variables
|
||||
delete_sql = """
|
||||
DELETE FROM workflow_draft_variables
|
||||
WHERE id IN :ids
|
||||
"""
|
||||
deleted_result = conn.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)})
|
||||
batch_deleted = deleted_result.rowcount
|
||||
total_deleted += batch_deleted
|
||||
|
||||
logger.info(click.style(f"Deleted {batch_deleted} draft variables (batch) for app {app_id}", fg="green"))
|
||||
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Deleted {total_deleted} total draft variables for app {app_id}. "
|
||||
f"Cleaned up {total_files_deleted} total associated files.",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
return total_deleted
|
||||
|
||||
|
||||
def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
|
||||
"""
|
||||
Delete Offload data associated with WorkflowDraftVariable file_ids.
|
||||
|
||||
This function:
|
||||
1. Finds WorkflowDraftVariableFile records by file_ids
|
||||
2. Deletes associated files from object storage
|
||||
3. Deletes UploadFile records
|
||||
4. Deletes WorkflowDraftVariableFile records
|
||||
|
||||
Args:
|
||||
conn: Database connection
|
||||
file_ids: List of WorkflowDraftVariableFile IDs
|
||||
|
||||
Returns:
|
||||
Number of files cleaned up
|
||||
"""
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
if not file_ids:
|
||||
return 0
|
||||
|
||||
files_deleted = 0
|
||||
|
||||
try:
|
||||
# Get WorkflowDraftVariableFile records and their associated UploadFile keys
|
||||
query_sql = """
|
||||
SELECT wdvf.id, uf.key, uf.id as upload_file_id
|
||||
FROM workflow_draft_variable_files wdvf
|
||||
JOIN upload_files uf ON wdvf.upload_file_id = uf.id
|
||||
WHERE wdvf.id IN :file_ids
|
||||
"""
|
||||
result = conn.execute(sa.text(query_sql), {"file_ids": tuple(file_ids)})
|
||||
file_records = list(result)
|
||||
|
||||
# Delete from object storage and collect upload file IDs
|
||||
upload_file_ids = []
|
||||
for _, storage_key, upload_file_id in file_records:
|
||||
try:
|
||||
storage.delete(storage_key)
|
||||
upload_file_ids.append(upload_file_id)
|
||||
files_deleted += 1
|
||||
except Exception:
|
||||
logging.exception("Failed to delete storage object %s", storage_key)
|
||||
# Continue with database cleanup even if storage deletion fails
|
||||
upload_file_ids.append(upload_file_id)
|
||||
|
||||
# Delete UploadFile records
|
||||
if upload_file_ids:
|
||||
delete_upload_files_sql = """
|
||||
DELETE FROM upload_files
|
||||
WHERE id IN :upload_file_ids
|
||||
"""
|
||||
conn.execute(sa.text(delete_upload_files_sql), {"upload_file_ids": tuple(upload_file_ids)})
|
||||
|
||||
# Delete WorkflowDraftVariableFile records
|
||||
delete_variable_files_sql = """
|
||||
DELETE FROM workflow_draft_variable_files
|
||||
WHERE id IN :file_ids
|
||||
"""
|
||||
conn.execute(sa.text(delete_variable_files_sql), {"file_ids": tuple(file_ids)})
|
||||
|
||||
except Exception:
|
||||
logging.exception("Error deleting draft variable offload data:")
|
||||
# Don't raise, as we want to continue with the main deletion process
|
||||
|
||||
return files_deleted
|
||||
|
||||
|
||||
def _delete_app_triggers(tenant_id: str, app_id: str):
|
||||
def del_app_trigger(trigger_id: str):
|
||||
db.session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
|
||||
|
||||
_delete_records(
|
||||
"""select id from app_triggers where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
|
||||
{"tenant_id": tenant_id, "app_id": app_id},
|
||||
del_app_trigger,
|
||||
"app trigger",
|
||||
)
|
||||
|
||||
|
||||
def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
|
||||
def del_plugin_trigger(trigger_id: str):
|
||||
db.session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
"""select id from workflow_plugin_triggers where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
|
||||
{"tenant_id": tenant_id, "app_id": app_id},
|
||||
del_plugin_trigger,
|
||||
"workflow plugin trigger",
|
||||
)
|
||||
|
||||
|
||||
def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
|
||||
def del_webhook_trigger(trigger_id: str):
|
||||
db.session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
"""select id from workflow_webhook_triggers where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
|
||||
{"tenant_id": tenant_id, "app_id": app_id},
|
||||
del_webhook_trigger,
|
||||
"workflow webhook trigger",
|
||||
)
|
||||
|
||||
|
||||
def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
|
||||
def del_schedule_plan(plan_id: str):
|
||||
db.session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
"""select id from workflow_schedule_plans where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
|
||||
{"tenant_id": tenant_id, "app_id": app_id},
|
||||
del_schedule_plan,
|
||||
"workflow schedule plan",
|
||||
)
|
||||
|
||||
|
||||
def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
|
||||
def del_trigger_log(log_id: str):
|
||||
db.session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
|
||||
|
||||
_delete_records(
|
||||
"""select id from workflow_trigger_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
|
||||
{"tenant_id": tenant_id, "app_id": app_id},
|
||||
del_trigger_log,
|
||||
"workflow trigger log",
|
||||
)
|
||||
|
||||
|
||||
def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
|
||||
while True:
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query_sql), params)
|
||||
if rs.rowcount == 0:
|
||||
break
|
||||
|
||||
for row in rs:
record_id = str(row.id)
|
||||
try:
|
||||
delete_func(record_id)
|
||||
db.session.commit()
|
||||
logger.info(click.style(f"Deleted {name} {record_id}", fg="green"))
|
||||
except Exception:
|
||||
logger.exception("Error occurred while deleting %s %s", name, record_id)
|
||||
continue
|
||||
rs.close()
|
||||
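# Illustrative sketch only (not part of this commit): any further per-app cleanup can
# reuse the _delete_records batching loop defined above. The table and model names
# below (app_example_records / AppExampleRecord) are hypothetical placeholders.
def _delete_app_example_records(tenant_id: str, app_id: str):
    def del_example_record(record_id: str):
        db.session.query(AppExampleRecord).where(AppExampleRecord.id == record_id).delete(synchronize_session=False)

    _delete_records(
        """select id from app_example_records where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
        {"tenant_id": tenant_id, "app_id": app_id},
        del_example_record,
        "app example record",
    )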
76
dify/api/tasks/remove_document_from_index_task.py
Normal file
76
dify/api/tasks/remove_document_from_index_task.py
Normal file
@@ -0,0 +1,76 @@
import logging
import time

import click
from celery import shared_task
from sqlalchemy import select

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Document, DocumentSegment

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def remove_document_from_index_task(document_id: str):
    """
    Async Remove document from index
    :param document_id: document id

    Usage: remove_document_from_index_task.delay(document_id)
    """
    logger.info(click.style(f"Start remove document segments from index: {document_id}", fg="green"))
    start_at = time.perf_counter()

    document = db.session.query(Document).where(Document.id == document_id).first()
    if not document:
        logger.info(click.style(f"Document not found: {document_id}", fg="red"))
        db.session.close()
        return

    if document.indexing_status != "completed":
        logger.info(click.style(f"Document is not completed, remove is not allowed: {document_id}", fg="red"))
        db.session.close()
        return

    indexing_cache_key = f"document_{document.id}_indexing"

    try:
        dataset = document.dataset

        if not dataset:
            raise Exception("Document has no dataset")

        index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

        segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()
        index_node_ids = [segment.index_node_id for segment in segments]
        if index_node_ids:
            try:
                index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
            except Exception:
                logger.exception("clean dataset %s from index failed", dataset.id)
        # update segment to disable
        db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update(
            {
                DocumentSegment.enabled: False,
                DocumentSegment.disabled_at: naive_utc_now(),
                DocumentSegment.disabled_by: document.disabled_by,
                DocumentSegment.updated_at: naive_utc_now(),
            }
        )
        db.session.commit()

        end_at = time.perf_counter()
        logger.info(click.style(f"Document removed from index: {document.id} latency: {end_at - start_at}", fg="green"))
    except Exception:
        logger.exception("remove document from index failed")
        if not document.archived:
            document.enabled = True
            db.session.commit()
    finally:
        redis_client.delete(indexing_cache_key)
        db.session.close()
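# Hedged usage sketch: a caller would typically set the per-document indexing cache key
# before dispatching, since the task above deletes that key in its finally block. The
# helper name and the 600-second TTL are assumptions for illustration.
def _dispatch_remove_from_index(document_id: str) -> None:
    indexing_cache_key = f"document_{document_id}_indexing"
    redis_client.setex(indexing_cache_key, 600, 1)
    remove_document_from_index_task.delay(document_id)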
125
dify/api/tasks/retry_document_indexing_task.py
Normal file
125
dify/api/tasks/retry_document_indexing_task.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Tenant
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from services.feature_service import FeatureService
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_id: str):
|
||||
"""
|
||||
Async retry indexing for documents in a dataset that previously failed
|
||||
:param dataset_id:
|
||||
:param document_ids:
|
||||
:param user_id:
|
||||
|
||||
Usage: retry_document_indexing_task.delay(dataset_id, document_ids, user_id)
|
||||
"""
|
||||
start_at = time.perf_counter()
|
||||
try:
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
|
||||
return
|
||||
user = db.session.query(Account).where(Account.id == user_id).first()
|
||||
if not user:
|
||||
logger.info(click.style(f"User not found: {user_id}", fg="red"))
|
||||
return
|
||||
tenant = db.session.query(Tenant).where(Tenant.id == dataset.tenant_id).first()
|
||||
if not tenant:
|
||||
raise ValueError("Tenant not found")
|
||||
user.current_tenant = tenant
|
||||
|
||||
for document_id in document_ids:
|
||||
retry_indexing_cache_key = f"document_{document_id}_is_retried"
|
||||
# check document limit
|
||||
features = FeatureService.get_features(tenant.id)
|
||||
try:
|
||||
if features.billing.enabled:
|
||||
vector_space = features.vector_space
|
||||
if 0 < vector_space.limit <= vector_space.size:
|
||||
raise ValueError(
|
||||
"Your total number of documents plus the number of uploads have over the limit of "
|
||||
"your subscription."
|
||||
)
|
||||
except Exception as e:
|
||||
document = (
|
||||
db.session.query(Document)
|
||||
.where(Document.id == document_id, Document.dataset_id == dataset_id)
|
||||
.first()
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
redis_client.delete(retry_indexing_cache_key)
|
||||
return
|
||||
|
||||
logger.info(click.style(f"Start retry document: {document_id}", fg="green"))
|
||||
document = (
|
||||
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
)
|
||||
if not document:
|
||||
logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
|
||||
return
|
||||
try:
|
||||
# clean old data
|
||||
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
|
||||
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
||||
for segment in segments:
|
||||
db.session.delete(segment)
|
||||
db.session.commit()
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
if dataset.runtime_mode == "rag_pipeline":
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
rag_pipeline_service.retry_error_document(dataset, document, user)
|
||||
else:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run([document])
|
||||
redis_client.delete(retry_indexing_cache_key)
|
||||
except Exception as ex:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(ex)
|
||||
document.stopped_at = naive_utc_now()
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
redis_client.delete(retry_indexing_cache_key)
|
||||
logger.exception("retry_document_indexing_task failed, document_id: %s", document_id)
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"retry_document_indexing_task failed, dataset_id: %s, document_ids: %s", dataset_id, document_ids
|
||||
)
|
||||
raise e
|
||||
finally:
|
||||
db.session.close()
|
||||
95
dify/api/tasks/sync_website_document_indexing_task.py
Normal file
95
dify/api/tasks/sync_website_document_indexing_task.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def sync_website_document_indexing_task(dataset_id: str, document_id: str):
|
||||
"""
|
||||
Async re-sync and re-index a website-imported document
|
||||
:param dataset_id:
|
||||
:param document_id:
|
||||
|
||||
Usage: sync_website_document_indexing_task.delay(dataset_id, document_id)
|
||||
"""
|
||||
start_at = time.perf_counter()
|
||||
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if dataset is None:
|
||||
raise ValueError("Dataset not found")
|
||||
|
||||
sync_indexing_cache_key = f"document_{document_id}_is_sync"
|
||||
# check document limit
|
||||
features = FeatureService.get_features(dataset.tenant_id)
|
||||
try:
|
||||
if features.billing.enabled:
|
||||
vector_space = features.vector_space
|
||||
if 0 < vector_space.limit <= vector_space.size:
|
||||
raise ValueError(
|
||||
"Your total number of documents plus the number of uploads have over the limit of "
|
||||
"your subscription."
|
||||
)
|
||||
except Exception as e:
|
||||
document = (
|
||||
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
redis_client.delete(sync_indexing_cache_key)
|
||||
return
|
||||
|
||||
logger.info(click.style(f"Start sync website document: {document_id}", fg="green"))
|
||||
document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
if not document:
|
||||
logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
|
||||
return
|
||||
try:
|
||||
# clean old data
|
||||
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
|
||||
|
||||
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
||||
for segment in segments:
|
||||
db.session.delete(segment)
|
||||
db.session.commit()
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run([document])
|
||||
redis_client.delete(sync_indexing_cache_key)
|
||||
except Exception as ex:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(ex)
|
||||
document.stopped_at = naive_utc_now()
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
redis_client.delete(sync_indexing_cache_key)
|
||||
logger.exception("sync_website_document_indexing_task failed, document_id: %s", document_id)
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Sync document: {document_id} latency: {end_at - start_at}", fg="green"))
|
||||
521
dify/api/tasks/trigger_processing_tasks.py
Normal file
521
dify/api/tasks/trigger_processing_tasks.py
Normal file
@@ -0,0 +1,521 @@
|
||||
"""
|
||||
Celery tasks for async trigger processing.
|
||||
|
||||
These tasks handle trigger workflow execution asynchronously
|
||||
to avoid blocking the main request thread.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.entities.request import TriggerInvokeEventResponse
|
||||
from core.plugin.impl.exc import PluginInvokeError
|
||||
from core.trigger.debug.event_bus import TriggerDebugEventBus
|
||||
from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key
|
||||
from core.trigger.entities.entities import TriggerProviderEntity
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.workflow.enums import NodeType, WorkflowExecutionStatus
|
||||
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
|
||||
from enums.quota_type import QuotaType, unlimited
|
||||
from extensions.ext_database import db
|
||||
from models.enums import (
|
||||
AppTriggerType,
|
||||
CreatorUserRole,
|
||||
WorkflowRunTriggeredFrom,
|
||||
WorkflowTriggerStatus,
|
||||
)
|
||||
from models.model import EndUser
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from models.trigger import TriggerSubscription, WorkflowPluginTrigger, WorkflowTriggerLog
|
||||
from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowRun
|
||||
from services.async_workflow_service import AsyncWorkflowService
|
||||
from services.end_user_service import EndUserService
|
||||
from services.errors.app import QuotaExceededError
|
||||
from services.trigger.app_trigger_service import AppTriggerService
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
from services.trigger.trigger_request_service import TriggerHttpRequestCachingService
|
||||
from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService
|
||||
from services.workflow.entities import PluginTriggerData, PluginTriggerDispatchData, PluginTriggerMetadata
|
||||
from services.workflow.queue_dispatcher import QueueDispatcherManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Use workflow queue for trigger processing
|
||||
TRIGGER_QUEUE = "triggered_workflow_dispatcher"
|
||||
|
||||
|
||||
def dispatch_trigger_debug_event(
|
||||
events: list[str],
|
||||
user_id: str,
|
||||
timestamp: int,
|
||||
request_id: str,
|
||||
subscription: TriggerSubscription,
|
||||
) -> int:
|
||||
debug_dispatched = 0
|
||||
try:
|
||||
for event_name in events:
|
||||
pool_key: str = build_plugin_pool_key(
|
||||
name=event_name,
|
||||
tenant_id=subscription.tenant_id,
|
||||
subscription_id=subscription.id,
|
||||
provider_id=subscription.provider_id,
|
||||
)
|
||||
trigger_debug_event: PluginTriggerDebugEvent = PluginTriggerDebugEvent(
|
||||
timestamp=timestamp,
|
||||
user_id=user_id,
|
||||
name=event_name,
|
||||
request_id=request_id,
|
||||
subscription_id=subscription.id,
|
||||
provider_id=subscription.provider_id,
|
||||
)
|
||||
debug_dispatched += TriggerDebugEventBus.dispatch(
|
||||
tenant_id=subscription.tenant_id,
|
||||
event=trigger_debug_event,
|
||||
pool_key=pool_key,
|
||||
)
|
||||
logger.debug(
|
||||
"Trigger debug dispatched %d sessions to pool %s for event %s for subscription %s provider %s",
|
||||
debug_dispatched,
|
||||
pool_key,
|
||||
event_name,
|
||||
subscription.id,
|
||||
subscription.provider_id,
|
||||
)
|
||||
return debug_dispatched
|
||||
except Exception:
|
||||
logger.exception("Failed to dispatch to debug sessions")
|
||||
return 0
|
||||
|
||||
|
||||
def _get_latest_workflows_by_app_ids(
|
||||
session: Session, subscribers: Sequence[WorkflowPluginTrigger]
|
||||
) -> Mapping[str, Workflow]:
|
||||
"""Get the latest workflows by app_ids"""
|
||||
workflow_query = (
|
||||
select(Workflow.app_id, func.max(Workflow.created_at).label("max_created_at"))
|
||||
.where(
|
||||
Workflow.app_id.in_({t.app_id for t in subscribers}),
|
||||
Workflow.version != Workflow.VERSION_DRAFT,
|
||||
)
|
||||
.group_by(Workflow.app_id)
|
||||
.subquery()
|
||||
)
|
||||
workflows = session.scalars(
|
||||
select(Workflow).join(
|
||||
workflow_query,
|
||||
(Workflow.app_id == workflow_query.c.app_id) & (Workflow.created_at == workflow_query.c.max_created_at),
|
||||
)
|
||||
).all()
|
||||
return {w.app_id: w for w in workflows}
|
||||
|
||||
|
||||
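# For clarity, the query above is a standard greatest-per-group lookup; roughly the SQL
# it produces (parameter names illustrative):
#
#   SELECT w.* FROM workflows w
#   JOIN (
#       SELECT app_id, MAX(created_at) AS max_created_at
#       FROM workflows
#       WHERE app_id IN (:app_ids) AND version != :draft_version
#       GROUP BY app_id
#   ) latest
#     ON w.app_id = latest.app_id AND w.created_at = latest.max_created_at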
def _record_trigger_failure_log(
|
||||
*,
|
||||
session: Session,
|
||||
workflow: Workflow,
|
||||
plugin_trigger: WorkflowPluginTrigger,
|
||||
subscription: TriggerSubscription,
|
||||
trigger_metadata: PluginTriggerMetadata,
|
||||
end_user: EndUser | None,
|
||||
error_message: str,
|
||||
event_name: str,
|
||||
request_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Persist a workflow run, workflow app log, and trigger log entry for failed trigger invocations.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
if end_user:
|
||||
created_by_role = CreatorUserRole.END_USER
|
||||
created_by = end_user.id
|
||||
else:
|
||||
created_by_role = CreatorUserRole.ACCOUNT
|
||||
created_by = subscription.user_id
|
||||
|
||||
failure_inputs = {
|
||||
"event_name": event_name,
|
||||
"subscription_id": subscription.id,
|
||||
"request_id": request_id,
|
||||
"plugin_trigger_id": plugin_trigger.id,
|
||||
}
|
||||
|
||||
workflow_run = WorkflowRun(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
workflow_id=workflow.id,
|
||||
type=workflow.type,
|
||||
triggered_from=WorkflowRunTriggeredFrom.PLUGIN.value,
|
||||
version=workflow.version,
|
||||
graph=workflow.graph,
|
||||
inputs=json.dumps(failure_inputs),
|
||||
status=WorkflowExecutionStatus.FAILED.value,
|
||||
outputs="{}",
|
||||
error=error_message,
|
||||
elapsed_time=0.0,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
created_by_role=created_by_role.value,
|
||||
created_by=created_by,
|
||||
created_at=now,
|
||||
finished_at=now,
|
||||
exceptions_count=0,
|
||||
)
|
||||
session.add(workflow_run)
|
||||
session.flush()
|
||||
|
||||
workflow_app_log = WorkflowAppLog(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API.value,
|
||||
created_by_role=created_by_role.value,
|
||||
created_by=created_by,
|
||||
)
|
||||
session.add(workflow_app_log)
|
||||
|
||||
dispatcher = QueueDispatcherManager.get_dispatcher(subscription.tenant_id)
|
||||
queue_name = dispatcher.get_queue_name()
|
||||
|
||||
trigger_data = PluginTriggerData(
|
||||
app_id=plugin_trigger.app_id,
|
||||
tenant_id=subscription.tenant_id,
|
||||
workflow_id=workflow.id,
|
||||
root_node_id=plugin_trigger.node_id,
|
||||
inputs={},
|
||||
trigger_metadata=trigger_metadata,
|
||||
plugin_id=subscription.provider_id,
|
||||
endpoint_id=subscription.endpoint_id,
|
||||
)
|
||||
|
||||
trigger_log = WorkflowTriggerLog(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
root_node_id=plugin_trigger.node_id,
|
||||
trigger_metadata=trigger_metadata.model_dump_json(),
|
||||
trigger_type=AppTriggerType.TRIGGER_PLUGIN,
|
||||
trigger_data=trigger_data.model_dump_json(),
|
||||
inputs=json.dumps({}),
|
||||
status=WorkflowTriggerStatus.FAILED,
|
||||
error=error_message,
|
||||
queue_name=queue_name,
|
||||
retry_count=0,
|
||||
created_by_role=created_by_role.value,
|
||||
created_by=created_by,
|
||||
triggered_at=now,
|
||||
finished_at=now,
|
||||
elapsed_time=0.0,
|
||||
total_tokens=0,
|
||||
outputs=None,
|
||||
celery_task_id=None,
|
||||
)
|
||||
session.add(trigger_log)
|
||||
session.commit()
|
||||
|
||||
|
||||
def dispatch_triggered_workflow(
|
||||
user_id: str,
|
||||
subscription: TriggerSubscription,
|
||||
event_name: str,
|
||||
request_id: str,
|
||||
) -> int:
|
||||
"""Process triggered workflows.
|
||||
|
||||
Args:
|
||||
subscription: The trigger subscription
|
||||
event_name: The name of the trigger event that was activated
|
||||
request_id: The ID of the stored request in storage system
|
||||
"""
|
||||
request = TriggerHttpRequestCachingService.get_request(request_id)
|
||||
payload = TriggerHttpRequestCachingService.get_payload(request_id)
|
||||
|
||||
subscribers: list[WorkflowPluginTrigger] = TriggerSubscriptionOperatorService.get_subscriber_triggers(
|
||||
tenant_id=subscription.tenant_id, subscription_id=subscription.id, event_name=event_name
|
||||
)
|
||||
if not subscribers:
|
||||
logger.warning(
|
||||
"No workflows found for trigger event '%s' in subscription '%s'",
|
||||
event_name,
|
||||
subscription.id,
|
||||
)
|
||||
return 0
|
||||
|
||||
dispatched_count = 0
|
||||
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
|
||||
tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id)
|
||||
)
|
||||
trigger_entity: TriggerProviderEntity = provider_controller.entity
|
||||
with Session(db.engine) as session:
|
||||
workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers)
|
||||
|
||||
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
|
||||
type=InvokeFrom.TRIGGER,
|
||||
tenant_id=subscription.tenant_id,
|
||||
app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers],
|
||||
user_id=user_id,
|
||||
)
|
||||
for plugin_trigger in subscribers:
|
||||
# Get workflow from mapping
|
||||
workflow: Workflow | None = workflows.get(plugin_trigger.app_id)
|
||||
if not workflow:
|
||||
logger.error(
|
||||
"Workflow not found for app %s",
|
||||
plugin_trigger.app_id,
|
||||
)
|
||||
continue
|
||||
|
||||
# Find the trigger node in the workflow
|
||||
event_node = None
|
||||
for node_id, node_config in workflow.walk_nodes(NodeType.TRIGGER_PLUGIN):
|
||||
if node_id == plugin_trigger.node_id:
|
||||
event_node = node_config
|
||||
break
|
||||
|
||||
if not event_node:
|
||||
logger.error("Trigger event node not found for app %s", plugin_trigger.app_id)
|
||||
continue
|
||||
|
||||
# invoke trigger
|
||||
trigger_metadata = PluginTriggerMetadata(
|
||||
plugin_unique_identifier=provider_controller.plugin_unique_identifier or "",
|
||||
endpoint_id=subscription.endpoint_id,
|
||||
provider_id=subscription.provider_id,
|
||||
event_name=event_name,
|
||||
icon_filename=trigger_entity.identity.icon or "",
|
||||
icon_dark_filename=trigger_entity.identity.icon_dark or "",
|
||||
)
|
||||
|
||||
# consume quota before invoking trigger
|
||||
quota_charge = unlimited()
|
||||
try:
|
||||
quota_charge = QuotaType.TRIGGER.consume(subscription.tenant_id)
|
||||
except QuotaExceededError:
|
||||
AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
|
||||
logger.info(
|
||||
"Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id
|
||||
)
|
||||
return 0
|
||||
|
||||
node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node)
|
||||
invoke_response: TriggerInvokeEventResponse | None = None
|
||||
try:
|
||||
invoke_response = TriggerManager.invoke_trigger_event(
|
||||
tenant_id=subscription.tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=TriggerProviderID(subscription.provider_id),
|
||||
event_name=event_name,
|
||||
parameters=node_data.resolve_parameters(
|
||||
parameter_schemas=provider_controller.get_event_parameters(event_name=event_name)
|
||||
),
|
||||
credentials=subscription.credentials,
|
||||
credential_type=CredentialType.of(subscription.credential_type),
|
||||
subscription=subscription.to_entity(),
|
||||
request=request,
|
||||
payload=payload,
|
||||
)
|
||||
except PluginInvokeError as e:
|
||||
quota_charge.refund()
|
||||
|
||||
error_message = e.to_user_friendly_error(plugin_name=trigger_entity.identity.name)
|
||||
try:
|
||||
end_user = end_users.get(plugin_trigger.app_id)
|
||||
_record_trigger_failure_log(
|
||||
session=session,
|
||||
workflow=workflow,
|
||||
plugin_trigger=plugin_trigger,
|
||||
subscription=subscription,
|
||||
trigger_metadata=trigger_metadata,
|
||||
end_user=end_user,
|
||||
error_message=error_message,
|
||||
event_name=event_name,
|
||||
request_id=request_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to record trigger failure log for app %s",
|
||||
plugin_trigger.app_id,
|
||||
)
|
||||
continue
|
||||
except Exception:
|
||||
quota_charge.refund()
|
||||
|
||||
logger.exception(
|
||||
"Failed to invoke trigger event for app %s",
|
||||
plugin_trigger.app_id,
|
||||
)
|
||||
continue
|
||||
|
||||
if invoke_response is not None and invoke_response.cancelled:
|
||||
quota_charge.refund()
|
||||
|
||||
logger.info(
|
||||
"Trigger ignored for app %s with trigger event %s",
|
||||
plugin_trigger.app_id,
|
||||
event_name,
|
||||
)
|
||||
continue
|
||||
|
||||
# Create trigger data for async execution
|
||||
trigger_data = PluginTriggerData(
|
||||
app_id=plugin_trigger.app_id,
|
||||
tenant_id=subscription.tenant_id,
|
||||
workflow_id=workflow.id,
|
||||
root_node_id=plugin_trigger.node_id,
|
||||
plugin_id=subscription.provider_id,
|
||||
endpoint_id=subscription.endpoint_id,
|
||||
inputs=invoke_response.variables,
|
||||
trigger_metadata=trigger_metadata,
|
||||
)
|
||||
|
||||
# Trigger async workflow
|
||||
try:
|
||||
end_user = end_users.get(plugin_trigger.app_id)
|
||||
if not end_user:
|
||||
raise ValueError(f"End user not found for app {plugin_trigger.app_id}")
|
||||
|
||||
AsyncWorkflowService.trigger_workflow_async(session=session, user=end_user, trigger_data=trigger_data)
|
||||
dispatched_count += 1
|
||||
logger.info(
|
||||
"Triggered workflow for app %s with trigger event %s",
|
||||
plugin_trigger.app_id,
|
||||
event_name,
|
||||
)
|
||||
except Exception:
|
||||
quota_charge.refund()
|
||||
|
||||
logger.exception(
|
||||
"Failed to trigger workflow for app %s",
|
||||
plugin_trigger.app_id,
|
||||
)
|
||||
|
||||
return dispatched_count
|
||||
|
||||
|
||||
def dispatch_triggered_workflows(
|
||||
user_id: str,
|
||||
events: list[str],
|
||||
subscription: TriggerSubscription,
|
||||
request_id: str,
|
||||
) -> int:
|
||||
dispatched_count = 0
|
||||
for event_name in events:
|
||||
try:
|
||||
dispatched_count += dispatch_triggered_workflow(
|
||||
user_id=user_id,
|
||||
subscription=subscription,
|
||||
event_name=event_name,
|
||||
request_id=request_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to dispatch trigger '%s' for subscription %s and provider %s. Continuing...",
|
||||
event_name,
|
||||
subscription.id,
|
||||
subscription.provider_id,
|
||||
)
|
||||
# Continue processing other triggers even if one fails
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
"Completed async trigger dispatching: processed %d/%d triggers for subscription %s and provider %s",
|
||||
dispatched_count,
|
||||
len(events),
|
||||
subscription.id,
|
||||
subscription.provider_id,
|
||||
)
|
||||
return dispatched_count
|
||||
|
||||
|
||||
@shared_task(queue=TRIGGER_QUEUE)
|
||||
def dispatch_triggered_workflows_async(
|
||||
dispatch_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Dispatch triggers asynchronously.
|
||||
|
||||
Args:
dispatch_data: serialized PluginTriggerDispatchData carrying user_id, tenant_id, endpoint_id,
provider_id, subscription_id, timestamp, events, and request_id
|
||||
|
||||
Returns:
|
||||
dict: Execution result with status and dispatched trigger count
|
||||
"""
|
||||
dispatch_params: PluginTriggerDispatchData = PluginTriggerDispatchData.model_validate(dispatch_data)
|
||||
user_id = dispatch_params.user_id
|
||||
tenant_id = dispatch_params.tenant_id
|
||||
endpoint_id = dispatch_params.endpoint_id
|
||||
provider_id = dispatch_params.provider_id
|
||||
subscription_id = dispatch_params.subscription_id
|
||||
timestamp = dispatch_params.timestamp
|
||||
events = dispatch_params.events
|
||||
request_id = dispatch_params.request_id
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
"Starting trigger dispatching uid=%s, endpoint=%s, events=%s, req_id=%s, sub_id=%s, provider_id=%s",
|
||||
user_id,
|
||||
endpoint_id,
|
||||
events,
|
||||
request_id,
|
||||
subscription_id,
|
||||
provider_id,
|
||||
)
|
||||
|
||||
subscription: TriggerSubscription | None = TriggerProviderService.get_subscription_by_id(
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
if not subscription:
|
||||
logger.error("Subscription not found: %s", subscription_id)
|
||||
return {"status": "failed", "error": "Subscription not found"}
|
||||
|
||||
workflow_dispatched = dispatch_triggered_workflows(
|
||||
user_id=user_id,
|
||||
events=events,
|
||||
subscription=subscription,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
debug_dispatched = dispatch_trigger_debug_event(
|
||||
events=events,
|
||||
user_id=user_id,
|
||||
timestamp=timestamp,
|
||||
request_id=request_id,
|
||||
subscription=subscription,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"total_count": len(events),
|
||||
"workflows": workflow_dispatched,
|
||||
"debug_events": debug_dispatched,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Error in async trigger dispatching for endpoint %s data %s for subscription %s and provider %s",
|
||||
endpoint_id,
|
||||
dispatch_data,
|
||||
subscription_id,
|
||||
provider_id,
|
||||
)
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
}
|
||||
119
dify/api/tasks/trigger_subscription_refresh_tasks.py
Normal file
119
dify/api/tasks/trigger_subscription_refresh_tasks.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.trigger.utils.locks import build_trigger_refresh_lock_key
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.trigger import TriggerSubscription
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _now_ts() -> int:
|
||||
return int(time.time())
|
||||
|
||||
|
||||
def _load_subscription(session: Session, tenant_id: str, subscription_id: str) -> TriggerSubscription | None:
|
||||
return session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||
|
||||
|
||||
def _refresh_oauth_if_expired(tenant_id: str, subscription: TriggerSubscription, now: int) -> None:
|
||||
threshold_seconds: int = int(dify_config.TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS)
|
||||
if (
|
||||
subscription.credential_expires_at != -1
|
||||
and int(subscription.credential_expires_at) <= now + threshold_seconds
|
||||
and CredentialType.of(subscription.credential_type) == CredentialType.OAUTH2
|
||||
):
|
||||
logger.info(
|
||||
"Refreshing OAuth token: tenant=%s subscription_id=%s expires_at=%s now=%s",
|
||||
tenant_id,
|
||||
subscription.id,
|
||||
subscription.credential_expires_at,
|
||||
now,
|
||||
)
|
||||
try:
|
||||
result: Mapping[str, Any] = TriggerProviderService.refresh_oauth_token(
|
||||
tenant_id=tenant_id, subscription_id=subscription.id
|
||||
)
|
||||
logger.info(
|
||||
"OAuth token refreshed: tenant=%s subscription_id=%s result=%s", tenant_id, subscription.id, result
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("OAuth refresh failed: tenant=%s subscription_id=%s", tenant_id, subscription.id)
|
||||
|
||||
|
||||
def _refresh_subscription_if_expired(
|
||||
tenant_id: str,
|
||||
subscription: TriggerSubscription,
|
||||
now: int,
|
||||
) -> None:
|
||||
threshold_seconds: int = int(dify_config.TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS)
|
||||
if subscription.expires_at == -1 or int(subscription.expires_at) > now + threshold_seconds:
|
||||
logger.debug(
|
||||
"Subscription not due: tenant=%s subscription_id=%s expires_at=%s now=%s threshold=%s",
|
||||
tenant_id,
|
||||
subscription.id,
|
||||
subscription.expires_at,
|
||||
now,
|
||||
threshold_seconds,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
result: Mapping[str, Any] = TriggerProviderService.refresh_subscription(
|
||||
tenant_id=tenant_id, subscription_id=subscription.id, now=now
|
||||
)
|
||||
logger.info(
|
||||
"Subscription refreshed: tenant=%s subscription_id=%s result=%s",
|
||||
tenant_id,
|
||||
subscription.id,
|
||||
result.get("result"),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Subscription refresh failed: tenant=%s id=%s", tenant_id, subscription.id)
|
||||
|
||||
|
||||
@shared_task(queue="trigger_refresh_executor")
|
||||
def trigger_subscription_refresh(tenant_id: str, subscription_id: str) -> None:
|
||||
"""Refresh a trigger subscription if needed, guarded by a Redis in-flight lock."""
|
||||
lock_key: str = build_trigger_refresh_lock_key(tenant_id, subscription_id)
|
||||
if not redis_client.get(lock_key):
|
||||
logger.debug("Refresh lock missing, skip: %s", lock_key)
|
||||
return
|
||||
|
||||
logger.info("Begin subscription refresh: tenant=%s id=%s", tenant_id, subscription_id)
|
||||
try:
|
||||
now: int = _now_ts()
|
||||
with Session(db.engine) as session:
|
||||
subscription: TriggerSubscription | None = _load_subscription(session, tenant_id, subscription_id)
|
||||
|
||||
if not subscription:
|
||||
logger.warning("Subscription not found: tenant=%s id=%s", tenant_id, subscription_id)
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
"Loaded subscription: tenant=%s id=%s cred_exp=%s sub_exp=%s now=%s",
|
||||
tenant_id,
|
||||
subscription.id,
|
||||
subscription.credential_expires_at,
|
||||
subscription.expires_at,
|
||||
now,
|
||||
)
|
||||
|
||||
_refresh_oauth_if_expired(tenant_id=tenant_id, subscription=subscription, now=now)
|
||||
_refresh_subscription_if_expired(tenant_id=tenant_id, subscription=subscription, now=now)
|
||||
finally:
|
||||
try:
|
||||
redis_client.delete(lock_key)
|
||||
logger.debug("Lock released: %s", lock_key)
|
||||
except Exception:
|
||||
# Best-effort lock cleanup
|
||||
logger.warning("Failed to release lock: %s", lock_key, exc_info=True)
|
||||
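# Hedged sketch of the producer side implied by the lock check above: whatever enqueues
# trigger_subscription_refresh is expected to set the in-flight lock first, so a missing
# lock means the refresh was cancelled or superseded. The helper name and 300-second TTL
# are assumptions for illustration.
def _enqueue_subscription_refresh(tenant_id: str, subscription_id: str) -> bool:
    lock_key = build_trigger_refresh_lock_key(tenant_id, subscription_id)
    # NX + EX: only enqueue when no refresh is already in flight.
    if not redis_client.set(lock_key, "1", nx=True, ex=300):
        return False
    trigger_subscription_refresh.delay(tenant_id, subscription_id)
    return True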
32
dify/api/tasks/workflow_cfs_scheduler/cfs_scheduler.py
Normal file
32
dify/api/tasks/workflow_cfs_scheduler/cfs_scheduler.py
Normal file
@@ -0,0 +1,32 @@
from services.workflow.entities import WorkflowScheduleCFSPlanEntity
from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand
from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue


class AsyncWorkflowCFSPlanEntity(WorkflowScheduleCFSPlanEntity):
    """
    Trigger workflow CFS plan entity.
    """

    queue: AsyncWorkflowQueue


class AsyncWorkflowCFSPlanScheduler(CFSPlanScheduler):
    """
    Trigger workflow CFS plan scheduler.
    """

    plan: AsyncWorkflowCFSPlanEntity

    def can_schedule(self) -> SchedulerCommand:
        """
        Check if the workflow can be scheduled.
        """
        if self.plan.queue in [AsyncWorkflowQueue.PROFESSIONAL_QUEUE, AsyncWorkflowQueue.TEAM_QUEUE]:
            # Paid tiers (professional, team) may schedule workflows at any time.
            return SchedulerCommand.NONE

        # FIXME: avoid leaving a sandbox user's workflow in a running state forever
        return SchedulerCommand.RESOURCE_LIMIT_REACHED
25
dify/api/tasks/workflow_cfs_scheduler/entities.py
Normal file
25
dify/api/tasks/workflow_cfs_scheduler/entities.py
Normal file
@@ -0,0 +1,25 @@
from enum import StrEnum

from configs import dify_config
from services.workflow.entities import WorkflowScheduleCFSPlanEntity

# Determine queue names based on edition
if dify_config.EDITION == "CLOUD":
    # Cloud edition: separate queues for different tiers
    _professional_queue = "workflow_professional"
    _team_queue = "workflow_team"
    _sandbox_queue = "workflow_sandbox"
    AsyncWorkflowSystemStrategy = WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice
else:
    # Community edition: a single shared workflow queue (distinct from the dataset queue)
    _professional_queue = "workflow"
    _team_queue = "workflow"
    _sandbox_queue = "workflow"
    AsyncWorkflowSystemStrategy = WorkflowScheduleCFSPlanEntity.Strategy.Nop


class AsyncWorkflowQueue(StrEnum):
    # Queue name constants; in the community edition all three resolve to "workflow"
    PROFESSIONAL_QUEUE = _professional_queue
    TEAM_QUEUE = _team_queue
    SANDBOX_QUEUE = _sandbox_queue
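# Hedged sketch: how a dispatcher might map a billing plan to one of the queues above.
# The plan names are assumptions for illustration; in the community edition all three
# enum members resolve to the same "workflow" queue, so the mapping collapses.
def queue_for_plan(plan: str) -> AsyncWorkflowQueue:
    if plan == "professional":
        return AsyncWorkflowQueue.PROFESSIONAL_QUEUE
    if plan == "team":
        return AsyncWorkflowQueue.TEAM_QUEUE
    return AsyncWorkflowQueue.SANDBOX_QUEUE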
22
dify/api/tasks/workflow_draft_var_tasks.py
Normal file
22
dify/api/tasks/workflow_draft_var_tasks.py
Normal file
@@ -0,0 +1,22 @@
"""
Celery tasks for asynchronous workflow draft variable cleanup.

These tasks delete offloaded draft variable files in the background so that
request threads are not blocked by storage operations.
"""

from celery import shared_task  # type: ignore[import-untyped]
from sqlalchemy.orm import Session

from extensions.ext_database import db
from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService


@shared_task(queue="workflow_draft_var", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_execution_task(
    self,
    deletions: list[DraftVarFileDeletion],
):
    with Session(bind=db.engine) as session, session.begin():
        srv = WorkflowDraftVariableService(session=session)
        srv.delete_workflow_draft_variable_file(deletions=deletions)
136
dify/api/tasks/workflow_execution_tasks.py
Normal file
136
dify/api/tasks/workflow_execution_tasks.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Celery tasks for asynchronous workflow execution storage operations.
|
||||
|
||||
These tasks provide asynchronous storage capabilities for workflow execution data,
|
||||
improving performance by offloading storage operations to background workers.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from extensions.ext_database import db
|
||||
from models import CreatorUserRole, WorkflowRun
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
|
||||
def save_workflow_execution_task(
|
||||
self,
|
||||
execution_data: dict,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
triggered_from: str,
|
||||
creator_user_id: str,
|
||||
creator_user_role: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Asynchronously save or update a workflow execution to the database.
|
||||
|
||||
Args:
|
||||
execution_data: Serialized WorkflowExecution data
|
||||
tenant_id: Tenant ID for multi-tenancy
|
||||
app_id: Application ID
|
||||
triggered_from: Source of the execution trigger
|
||||
creator_user_id: ID of the user who created the execution
|
||||
creator_user_role: Role of the user who created the execution
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Create a new session for this task
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
|
||||
with session_factory() as session:
|
||||
# Deserialize execution data
|
||||
execution = WorkflowExecution.model_validate(execution_data)
|
||||
|
||||
# Check if workflow run already exists
|
||||
existing_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == execution.id_))
|
||||
|
||||
if existing_run:
|
||||
# Update existing workflow run
|
||||
_update_workflow_run_from_execution(existing_run, execution)
|
||||
logger.debug("Updated existing workflow run: %s", execution.id_)
|
||||
else:
|
||||
# Create new workflow run
|
||||
workflow_run = _create_workflow_run_from_execution(
|
||||
execution=execution,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom(triggered_from),
|
||||
creator_user_id=creator_user_id,
|
||||
creator_user_role=CreatorUserRole(creator_user_role),
|
||||
)
|
||||
session.add(workflow_run)
|
||||
logger.debug("Created new workflow run: %s", execution.id_)
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to save workflow execution %s", execution_data.get("id_", "unknown"))
|
||||
# Retry the task with exponential backoff
|
||||
raise self.retry(exc=e, countdown=60 * (2**self.request.retries))
|
||||
|
||||
|
||||
def _create_workflow_run_from_execution(
|
||||
execution: WorkflowExecution,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
triggered_from: WorkflowRunTriggeredFrom,
|
||||
creator_user_id: str,
|
||||
creator_user_role: CreatorUserRole,
|
||||
) -> WorkflowRun:
|
||||
"""
|
||||
Create a WorkflowRun database model from a WorkflowExecution domain entity.
|
||||
"""
|
||||
workflow_run = WorkflowRun()
|
||||
workflow_run.id = execution.id_
|
||||
workflow_run.tenant_id = tenant_id
|
||||
workflow_run.app_id = app_id
|
||||
workflow_run.workflow_id = execution.workflow_id
|
||||
workflow_run.type = execution.workflow_type.value
|
||||
workflow_run.triggered_from = triggered_from.value
|
||||
workflow_run.version = execution.workflow_version
|
||||
json_converter = WorkflowRuntimeTypeConverter()
|
||||
workflow_run.graph = json.dumps(json_converter.to_json_encodable(execution.graph))
|
||||
workflow_run.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs))
|
||||
workflow_run.status = execution.status.value
|
||||
workflow_run.outputs = (
|
||||
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
|
||||
)
|
||||
workflow_run.error = execution.error_message
|
||||
workflow_run.elapsed_time = execution.elapsed_time
|
||||
workflow_run.total_tokens = execution.total_tokens
|
||||
workflow_run.total_steps = execution.total_steps
|
||||
workflow_run.created_by_role = creator_user_role.value
|
||||
workflow_run.created_by = creator_user_id
|
||||
workflow_run.created_at = execution.started_at
|
||||
workflow_run.finished_at = execution.finished_at
|
||||
|
||||
return workflow_run
|
||||
|
||||
|
||||
def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: WorkflowExecution):
|
||||
"""
|
||||
Update a WorkflowRun database model from a WorkflowExecution domain entity.
|
||||
"""
|
||||
json_converter = WorkflowRuntimeTypeConverter()
|
||||
workflow_run.status = execution.status.value
|
||||
workflow_run.outputs = (
|
||||
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
|
||||
)
|
||||
workflow_run.error = execution.error_message
|
||||
workflow_run.elapsed_time = execution.elapsed_time
|
||||
workflow_run.total_tokens = execution.total_tokens
|
||||
workflow_run.total_steps = execution.total_steps
|
||||
workflow_run.finished_at = execution.finished_at
|
||||
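# Hedged sketch of the producer side: a repository wanting asynchronous persistence could
# serialize the domain entity and enqueue the task above. The helper name is hypothetical;
# model_dump(mode="json") assumes WorkflowExecution is a Pydantic v2 model, consistent with
# the model_validate call in the task.
def enqueue_workflow_run_save(
    execution: WorkflowExecution,
    tenant_id: str,
    app_id: str,
    triggered_from: WorkflowRunTriggeredFrom,
    user_id: str,
    user_role: CreatorUserRole,
) -> None:
    save_workflow_execution_task.delay(
        execution_data=execution.model_dump(mode="json"),
        tenant_id=tenant_id,
        app_id=app_id,
        triggered_from=triggered_from.value,
        creator_user_id=user_id,
        creator_user_role=user_role.value,
    )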
169
dify/api/tasks/workflow_node_execution_tasks.py
Normal file
169
dify/api/tasks/workflow_node_execution_tasks.py
Normal file
@@ -0,0 +1,169 @@
"""
Celery tasks for asynchronous workflow node execution storage operations.

These tasks provide asynchronous storage capabilities for workflow node execution data,
improving performance by offloading storage operations to background workers.
"""

import json
import logging

from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker

from core.workflow.entities.workflow_node_execution import (
    WorkflowNodeExecution,
)
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from models import CreatorUserRole, WorkflowNodeExecutionModel
from models.workflow import WorkflowNodeExecutionTriggeredFrom

logger = logging.getLogger(__name__)


@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_node_execution_task(
    self,
    execution_data: dict,
    tenant_id: str,
    app_id: str,
    triggered_from: str,
    creator_user_id: str,
    creator_user_role: str,
) -> bool:
    """
    Asynchronously save or update a workflow node execution to the database.

    Args:
        execution_data: Serialized WorkflowNodeExecution data
        tenant_id: Tenant ID for multi-tenancy
        app_id: Application ID
        triggered_from: Source of the execution trigger
        creator_user_id: ID of the user who created the execution
        creator_user_role: Role of the user who created the execution

    Returns:
        True if successful, False otherwise
    """
    try:
        # Create a new session for this task
        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

        with session_factory() as session:
            # Deserialize execution data
            execution = WorkflowNodeExecution.model_validate(execution_data)

            # Check if node execution already exists
            existing_execution = session.scalar(
                select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution.id)
            )

            if existing_execution:
                # Update existing node execution
                _update_node_execution_from_domain(existing_execution, execution)
                logger.debug("Updated existing workflow node execution: %s", execution.id)
            else:
                # Create new node execution
                node_execution = _create_node_execution_from_domain(
                    execution=execution,
                    tenant_id=tenant_id,
                    app_id=app_id,
                    triggered_from=WorkflowNodeExecutionTriggeredFrom(triggered_from),
                    creator_user_id=creator_user_id,
                    creator_user_role=CreatorUserRole(creator_user_role),
                )
                session.add(node_execution)
                logger.debug("Created new workflow node execution: %s", execution.id)

            session.commit()
            return True

    except Exception as e:
        logger.exception("Failed to save workflow node execution %s", execution_data.get("id", "unknown"))
        # Retry the task with exponential backoff
        raise self.retry(exc=e, countdown=60 * (2**self.request.retries))

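# Illustrative usage sketch only (not part of the committed module): a repository layer would
# typically serialize the domain entity and enqueue the task above. The concrete caller is not
# shown in this commit; the WORKFLOW_RUN / ACCOUNT enum members and model_dump(mode="json") are
# assumptions inferred from how the task deserializes its payload with model_validate.
def _example_enqueue(execution: WorkflowNodeExecution) -> None:
    # Hypothetical caller with placeholder identifiers.
    save_workflow_node_execution_task.delay(
        execution_data=execution.model_dump(mode="json"),
        tenant_id="tenant-id",
        app_id="app-id",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
        creator_user_id="user-id",
        creator_user_role=CreatorUserRole.ACCOUNT.value,
    )
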
def _create_node_execution_from_domain(
    execution: WorkflowNodeExecution,
    tenant_id: str,
    app_id: str,
    triggered_from: WorkflowNodeExecutionTriggeredFrom,
    creator_user_id: str,
    creator_user_role: CreatorUserRole,
) -> WorkflowNodeExecutionModel:
    """
    Create a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
    """
    node_execution = WorkflowNodeExecutionModel()
    node_execution.id = execution.id
    node_execution.tenant_id = tenant_id
    node_execution.app_id = app_id
    node_execution.workflow_id = execution.workflow_id
    node_execution.triggered_from = triggered_from.value
    node_execution.workflow_run_id = execution.workflow_execution_id
    node_execution.index = execution.index
    node_execution.predecessor_node_id = execution.predecessor_node_id
    node_execution.node_id = execution.node_id
    node_execution.node_type = execution.node_type.value
    node_execution.title = execution.title
    node_execution.node_execution_id = execution.node_execution_id

    # Serialize complex data as JSON
    json_converter = WorkflowRuntimeTypeConverter()
    node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
    node_execution.process_data = (
        json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
    )
    node_execution.outputs = (
        json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
    )
    # Convert metadata enum keys to strings for JSON serialization
    if execution.metadata:
        metadata_for_json = {
            key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
        }
        node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json))
    else:
        node_execution.execution_metadata = "{}"

    node_execution.status = execution.status.value
    node_execution.error = execution.error
    node_execution.elapsed_time = execution.elapsed_time
    node_execution.created_by_role = creator_user_role.value
    node_execution.created_by = creator_user_id
    node_execution.created_at = execution.created_at
    node_execution.finished_at = execution.finished_at

    return node_execution


def _update_node_execution_from_domain(node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution):
    """
    Update a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
    """
    # Update serialized data
    json_converter = WorkflowRuntimeTypeConverter()
    node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
    node_execution.process_data = (
        json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
    )
    node_execution.outputs = (
        json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
    )
    # Convert metadata enum keys to strings for JSON serialization
    if execution.metadata:
        metadata_for_json = {
            key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
        }
        node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json))
    else:
        node_execution.execution_metadata = "{}"

    # Update other fields
    node_execution.status = execution.status.value
    node_execution.error = execution.error
    node_execution.elapsed_time = execution.elapsed_time
    node_execution.finished_at = execution.finished_at
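

# Illustrative sketch only (not part of the committed module): the metadata handling shared by the
# two helpers above flattens enum keys to their .value (falling back to str()) so the mapping can
# be stored as JSON text in execution_metadata.
def _example_metadata_to_json(metadata: dict) -> str:
    # Hypothetical helper mirroring the dict comprehension used above.
    flattened = {key.value if hasattr(key, "value") else str(key): value for key, value in metadata.items()}
    return json.dumps(WorkflowRuntimeTypeConverter().to_json_encodable(flattened))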
73
dify/api/tasks/workflow_schedule_tasks.py
Normal file
@@ -0,0 +1,73 @@
import logging

from celery import shared_task
from sqlalchemy.orm import sessionmaker

from core.workflow.nodes.trigger_schedule.exc import (
    ScheduleExecutionError,
    ScheduleNotFoundError,
    TenantOwnerNotFoundError,
)
from enums.quota_type import QuotaType, unlimited
from extensions.ext_database import db
from models.trigger import WorkflowSchedulePlan
from services.async_workflow_service import AsyncWorkflowService
from services.errors.app import QuotaExceededError
from services.trigger.app_trigger_service import AppTriggerService
from services.trigger.schedule_service import ScheduleService
from services.workflow.entities import ScheduleTriggerData

logger = logging.getLogger(__name__)


@shared_task(queue="schedule_executor")
def run_schedule_trigger(schedule_id: str) -> None:
    """
    Execute a scheduled workflow trigger.

    Note: No retry logic needed as schedules will run again at next interval.
    The execution result is tracked via WorkflowTriggerLog.

    Raises:
        ScheduleNotFoundError: If schedule doesn't exist
        TenantOwnerNotFoundError: If no owner/admin for tenant
        ScheduleExecutionError: If workflow trigger fails
    """
    session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

    with session_factory() as session:
        schedule = session.get(WorkflowSchedulePlan, schedule_id)
        if not schedule:
            raise ScheduleNotFoundError(f"Schedule {schedule_id} not found")

        tenant_owner = ScheduleService.get_tenant_owner(session, schedule.tenant_id)
        if not tenant_owner:
            raise TenantOwnerNotFoundError(f"No owner or admin found for tenant {schedule.tenant_id}")

        quota_charge = unlimited()
        try:
            quota_charge = QuotaType.TRIGGER.consume(schedule.tenant_id)
        except QuotaExceededError:
            AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
            logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
            return

        try:
            # Production dispatch: Trigger the workflow normally
            response = AsyncWorkflowService.trigger_workflow_async(
                session=session,
                user=tenant_owner,
                trigger_data=ScheduleTriggerData(
                    app_id=schedule.app_id,
                    root_node_id=schedule.node_id,
                    inputs={},
                    tenant_id=schedule.tenant_id,
                ),
            )
            logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
        except Exception as e:
            quota_charge.refund()
            raise ScheduleExecutionError(
                f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}"
            ) from e
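

# Illustrative usage sketch only (not part of the committed module): the beat/scheduler side is not
# shown in this commit, but a dispatcher that has already resolved the due schedule plan IDs could
# fan them out to the "schedule_executor" queue like this.
def _example_dispatch(due_schedule_ids: list[str]) -> None:
    # Hypothetical fan-out loop with placeholder IDs.
    for schedule_id in due_schedule_ids:
        run_schedule_trigger.delay(schedule_id)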