dify
40  dify/api/events/event_handlers/__init__.py  Normal file
@@ -0,0 +1,40 @@
from .clean_when_dataset_deleted import handle as handle_clean_when_dataset_deleted
from .clean_when_document_deleted import handle as handle_clean_when_document_deleted
from .create_document_index import handle as handle_create_document_index
from .create_installed_app_when_app_created import handle as handle_create_installed_app_when_app_created
from .create_site_record_when_app_created import handle as handle_create_site_record_when_app_created
from .delete_tool_parameters_cache_when_sync_draft_workflow import (
    handle as handle_delete_tool_parameters_cache_when_sync_draft_workflow,
)
from .sync_plugin_trigger_when_app_created import handle as handle_sync_plugin_trigger_when_app_created
from .sync_webhook_when_app_created import handle as handle_sync_webhook_when_app_created
from .sync_workflow_schedule_when_app_published import handle as handle_sync_workflow_schedule_when_app_published
from .update_app_dataset_join_when_app_model_config_updated import (
    handle as handle_update_app_dataset_join_when_app_model_config_updated,
)
from .update_app_dataset_join_when_app_published_workflow_updated import (
    handle as handle_update_app_dataset_join_when_app_published_workflow_updated,
)
from .update_app_triggers_when_app_published_workflow_updated import (
    handle as handle_update_app_triggers_when_app_published_workflow_updated,
)

# Consolidated handler replaces both deduct_quota_when_message_created and
# update_provider_last_used_at_when_message_created
from .update_provider_when_message_created import handle as handle_update_provider_when_message_created

__all__ = [
    "handle_clean_when_dataset_deleted",
    "handle_clean_when_document_deleted",
    "handle_create_document_index",
    "handle_create_installed_app_when_app_created",
    "handle_create_site_record_when_app_created",
    "handle_delete_tool_parameters_cache_when_sync_draft_workflow",
    "handle_sync_plugin_trigger_when_app_created",
    "handle_sync_webhook_when_app_created",
    "handle_sync_workflow_schedule_when_app_published",
    "handle_update_app_dataset_join_when_app_model_config_updated",
    "handle_update_app_dataset_join_when_app_published_workflow_updated",
    "handle_update_app_triggers_when_app_published_workflow_updated",
    "handle_update_provider_when_message_created",
]
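Nothing in this package is called directly; each module registers its handle function on a blinker signal at import time via the @<signal>.connect decorators shown below. A minimal sketch of how an application factory might pull the package in so that all subscribers get connected — the ext_events.py module name and init_app hook here are assumptions for illustration, not necessarily how dify wires it up:

    # hypothetical wiring sketch: extensions/ext_events.py
    def init_app(app):
        # Importing the package executes every handler module, which runs the
        # @signal.connect decorators and registers the subscribers with blinker.
        from events import event_handlers  # noqa: F401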
18  dify/api/events/event_handlers/clean_when_dataset_deleted.py  Normal file
@@ -0,0 +1,18 @@
from events.dataset_event import dataset_was_deleted
from models import Dataset
from tasks.clean_dataset_task import clean_dataset_task


@dataset_was_deleted.connect
def handle(sender: Dataset, **kwargs):
    dataset = sender
    if not dataset.doc_form or not dataset.indexing_technique:
        return
    clean_dataset_task.delay(
        dataset.id,
        dataset.tenant_id,
        dataset.indexing_technique,
        dataset.index_struct,
        dataset.collection_binding_id,
        dataset.doc_form,
    )
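The decorator above subscribes handle to the dataset_was_deleted signal, and the handler hands the heavy cleanup off to Celery via clean_dataset_task.delay(). A plausible shape for the signal module and the emitting side, assuming the signals are plain blinker signals (the signal name string and the call site are illustrative, not taken from this commit):

    # events/dataset_event.py — plausible definition, assuming blinker
    from blinker import signal

    # sender: Dataset instance
    dataset_was_deleted = signal("dataset-was-deleted")

    # At the call site that deletes a dataset:
    # dataset_was_deleted.send(dataset)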
13  dify/api/events/event_handlers/clean_when_document_deleted.py  Normal file
@@ -0,0 +1,13 @@
from events.document_event import document_was_deleted
from tasks.clean_document_task import clean_document_task


@document_was_deleted.connect
def handle(sender, **kwargs):
    document_id = sender
    dataset_id = kwargs.get("dataset_id")
    doc_form = kwargs.get("doc_form")
    file_id = kwargs.get("file_id")
    if not dataset_id or not doc_form:
        return
    clean_document_task.delay(document_id, dataset_id, doc_form, file_id)
51  dify/api/events/event_handlers/create_document_index.py  Normal file
@@ -0,0 +1,51 @@
import contextlib
import logging
import time

import click
from werkzeug.exceptions import NotFound

from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from events.document_index_event import document_index_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Document

logger = logging.getLogger(__name__)


@document_index_created.connect
def handle(sender, **kwargs):
    dataset_id = sender
    document_ids = kwargs.get("document_ids", [])
    documents = []
    start_at = time.perf_counter()
    for document_id in document_ids:
        logger.info(click.style(f"Start process document: {document_id}", fg="green"))

        document = (
            db.session.query(Document)
            .where(
                Document.id == document_id,
                Document.dataset_id == dataset_id,
            )
            .first()
        )

        if not document:
            raise NotFound("Document not found")

        document.indexing_status = "parsing"
        document.processing_started_at = naive_utc_now()
        documents.append(document)
        db.session.add(document)
    db.session.commit()

    with contextlib.suppress(Exception):
        try:
            indexing_runner = IndexingRunner()
            indexing_runner.run(documents)
            end_at = time.perf_counter()
            logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
        except DocumentIsPausedError as ex:
            logger.info(click.style(str(ex), fg="yellow"))
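Note that this handler's sender is a dataset ID rather than a model instance, and the documents to index arrive through the document_ids keyword. A hedged sketch of the emitting side (the variable names are illustrative, not taken from this commit):

    # Illustrative emit, assuming blinker's send(sender, **kwargs) contract
    document_index_created.send(
        dataset.id,
        document_ids=[document.id for document in documents],
    )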
16  dify/api/events/event_handlers/create_installed_app_when_app_created.py  Normal file
@@ -0,0 +1,16 @@
from events.app_event import app_was_created
from extensions.ext_database import db
from models.model import InstalledApp


@app_was_created.connect
def handle(sender, **kwargs):
    """Create an installed app when an app is created."""
    app = sender
    installed_app = InstalledApp(
        tenant_id=app.tenant_id,
        app_id=app.id,
        app_owner_tenant_id=app.tenant_id,
    )
    db.session.add(installed_app)
    db.session.commit()
26  dify/api/events/event_handlers/create_site_record_when_app_created.py  Normal file
@@ -0,0 +1,26 @@
from events.app_event import app_was_created
from extensions.ext_database import db
from models.model import Site


@app_was_created.connect
def handle(sender, **kwargs):
    """Create a site record when an app is created."""
    app = sender
    account = kwargs.get("account")
    if account is not None:
        site = Site(
            app_id=app.id,
            title=app.name,
            icon_type=app.icon_type,
            icon=app.icon,
            icon_background=app.icon_background,
            default_language=account.interface_language,
            customize_token_strategy="not_allow",
            code=Site.generate_code(16),
            created_by=app.created_by,
            updated_by=app.updated_by,
        )

        db.session.add(site)
        db.session.commit()
35  dify/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py  Normal file
@@ -0,0 +1,35 @@
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolEntity
from events.app_event import app_draft_workflow_was_synced


@app_draft_workflow_was_synced.connect
def handle(sender, **kwargs):
    app = sender
    synced_draft_workflow = kwargs.get("synced_draft_workflow")
    if synced_draft_workflow is None:
        return
    for node_data in synced_draft_workflow.graph_dict.get("nodes", []):
        if node_data.get("data", {}).get("type") == NodeType.TOOL:
            try:
                tool_entity = ToolEntity.model_validate(node_data["data"])
                tool_runtime = ToolManager.get_tool_runtime(
                    provider_type=tool_entity.provider_type,
                    provider_id=tool_entity.provider_id,
                    tool_name=tool_entity.tool_name,
                    tenant_id=app.tenant_id,
                    credential_id=tool_entity.credential_id,
                )
                manager = ToolParameterConfigurationManager(
                    tenant_id=app.tenant_id,
                    tool_runtime=tool_runtime,
                    provider_name=tool_entity.provider_name,
                    provider_type=tool_entity.provider_type,
                    identity_id=f"WORKFLOW.{app.id}.{node_data.get('id')}",
                )
                manager.delete_tool_parameters_cache()
            except Exception:
                # tool does not exist
                pass
22  dify/api/events/event_handlers/sync_plugin_trigger_when_app_created.py  Normal file
@@ -0,0 +1,22 @@
import logging

from events.app_event import app_draft_workflow_was_synced
from models.model import App, AppMode
from models.workflow import Workflow
from services.trigger.trigger_service import TriggerService

logger = logging.getLogger(__name__)


@app_draft_workflow_was_synced.connect
def handle(sender, synced_draft_workflow: Workflow, **kwargs):
    """
    While creating or updating a workflow, we may need to sync
    its plugin trigger relationships in the DB.
    """
    app: App = sender
    if app.mode != AppMode.WORKFLOW.value:
        # only handle workflow apps; chatflow is not supported yet
        return

    TriggerService.sync_plugin_trigger_relationships(app, synced_draft_workflow)
22  dify/api/events/event_handlers/sync_webhook_when_app_created.py  Normal file
@@ -0,0 +1,22 @@
import logging

from events.app_event import app_draft_workflow_was_synced
from models.model import App, AppMode
from models.workflow import Workflow
from services.trigger.webhook_service import WebhookService

logger = logging.getLogger(__name__)


@app_draft_workflow_was_synced.connect
def handle(sender, synced_draft_workflow: Workflow, **kwargs):
    """
    While creating or updating a workflow, we may need to sync
    its webhook relationships in the DB.
    """
    app: App = sender
    if app.mode != AppMode.WORKFLOW.value:
        # only handle workflow apps; chatflow is not supported yet
        return

    WebhookService.sync_webhook_relationships(app, synced_draft_workflow)
86  dify/api/events/event_handlers/sync_workflow_schedule_when_app_published.py  Normal file
@@ -0,0 +1,86 @@
import logging
from typing import cast

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.workflow.nodes.trigger_schedule.entities import SchedulePlanUpdate
from events.app_event import app_published_workflow_was_updated
from extensions.ext_database import db
from models import AppMode, Workflow, WorkflowSchedulePlan
from services.trigger.schedule_service import ScheduleService

logger = logging.getLogger(__name__)


@app_published_workflow_was_updated.connect
def handle(sender, **kwargs):
    """
    Handle the app-published-workflow-updated event to sync the workflow_schedule_plans table.

    When a workflow is published, this handler will:
    1. Extract schedule trigger nodes from the workflow graph
    2. Compare them with existing workflow_schedule_plans records
    3. Create/update/delete schedule plans as needed
    """
    app = sender
    if app.mode != AppMode.WORKFLOW.value:
        return

    published_workflow = kwargs.get("published_workflow")
    published_workflow = cast(Workflow, published_workflow)

    sync_schedule_from_workflow(tenant_id=app.tenant_id, app_id=app.id, workflow=published_workflow)


def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow) -> WorkflowSchedulePlan | None:
    """
    Sync the schedule plan from the workflow graph configuration.

    Args:
        tenant_id: Tenant ID
        app_id: App ID
        workflow: Published workflow instance

    Returns:
        The updated or created WorkflowSchedulePlan, or None if the workflow has no schedule node
    """
    with Session(db.engine) as session:
        schedule_config = ScheduleService.extract_schedule_config(workflow)

        existing_plan = session.scalar(
            select(WorkflowSchedulePlan).where(
                WorkflowSchedulePlan.tenant_id == tenant_id,
                WorkflowSchedulePlan.app_id == app_id,
            )
        )

        if not schedule_config:
            if existing_plan:
                logger.info("No schedule node in workflow for app %s, removing schedule plan", app_id)
                ScheduleService.delete_schedule(session=session, schedule_id=existing_plan.id)
                session.commit()
            return None

        if existing_plan:
            updates = SchedulePlanUpdate(
                node_id=schedule_config.node_id,
                cron_expression=schedule_config.cron_expression,
                timezone=schedule_config.timezone,
            )
            updated_plan = ScheduleService.update_schedule(
                session=session,
                schedule_id=existing_plan.id,
                updates=updates,
            )
            session.commit()
            return updated_plan
        else:
            new_plan = ScheduleService.create_schedule(
                session=session,
                tenant_id=tenant_id,
                app_id=app_id,
                config=schedule_config,
            )
            session.commit()
            return new_plan
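The sync function only relies on three attributes of whatever ScheduleService.extract_schedule_config returns: node_id, cron_expression, and timezone. A minimal illustration of that contract — the ScheduleConfig dataclass below is a hypothetical stand-in, not the actual entity class used by dify:

    from dataclasses import dataclass


    @dataclass
    class ScheduleConfig:  # hypothetical stand-in for the extracted config
        node_id: str
        cron_expression: str
        timezone: str


    config = ScheduleConfig(node_id="trigger-1", cron_expression="0 9 * * 1", timezone="UTC")
    # With an existing plan, the handler builds SchedulePlanUpdate(node_id=..., cron_expression=..., timezone=...)
    # and calls ScheduleService.update_schedule(); otherwise it calls ScheduleService.create_schedule().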
70  dify/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py  Normal file
@@ -0,0 +1,70 @@
from sqlalchemy import select

from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from models.dataset import AppDatasetJoin
from models.model import AppModelConfig


@app_model_config_was_updated.connect
def handle(sender, **kwargs):
    app = sender
    app_model_config = kwargs.get("app_model_config")
    if app_model_config is None:
        return

    dataset_ids = get_dataset_ids_from_model_config(app_model_config)

    app_dataset_joins = db.session.scalars(select(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id)).all()

    removed_dataset_ids: set[str] = set()
    if not app_dataset_joins:
        added_dataset_ids = dataset_ids
    else:
        old_dataset_ids: set[str] = set()
        old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins)

        added_dataset_ids = dataset_ids - old_dataset_ids
        removed_dataset_ids = old_dataset_ids - dataset_ids

    if removed_dataset_ids:
        for dataset_id in removed_dataset_ids:
            db.session.query(AppDatasetJoin).where(
                AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
            ).delete()

    if added_dataset_ids:
        for dataset_id in added_dataset_ids:
            app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id)
            db.session.add(app_dataset_join)

    db.session.commit()


def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set[str]:
    dataset_ids: set[str] = set()
    if not app_model_config:
        return dataset_ids

    agent_mode = app_model_config.agent_mode_dict

    tools = agent_mode.get("tools", []) or []
    for tool in tools:
        if len(list(tool.keys())) != 1:
            continue

        tool_type = list(tool.keys())[0]
        tool_config = list(tool.values())[0]
        if tool_type == "dataset":
            dataset_ids.add(tool_config.get("id"))

    # get datasets from dataset_configs
    dataset_configs = app_model_config.dataset_configs_dict
    datasets = dataset_configs.get("datasets", {}) or {}
    for dataset in datasets.get("datasets", []) or []:
        keys = list(dataset.keys())
        if len(keys) == 1 and keys[0] == "dataset":
            if dataset["dataset"].get("id"):
                dataset_ids.add(dataset["dataset"].get("id"))

    return dataset_ids
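get_dataset_ids_from_model_config looks for dataset IDs in two places: single-key {"dataset": {...}} entries under agent_mode.tools, and nested {"dataset": {"id": ...}} entries under dataset_configs.datasets.datasets. A small illustration of inputs it would accept (the IDs are placeholders):

    # Hypothetical agent_mode_dict: one qualifying "dataset" tool entry
    agent_mode = {"tools": [{"dataset": {"id": "ds-111"}}]}

    # Hypothetical dataset_configs_dict: nested "datasets" list with single-key entries
    dataset_configs = {"datasets": {"datasets": [{"dataset": {"id": "ds-222"}}]}}

    # Given an AppModelConfig whose *_dict properties return the above,
    # get_dataset_ids_from_model_config would yield {"ds-111", "ds-222"}.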
69  dify/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py  Normal file
@@ -0,0 +1,69 @@
from typing import cast

from sqlalchemy import select

from core.workflow.nodes import NodeType
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from events.app_event import app_published_workflow_was_updated
from extensions.ext_database import db
from models.dataset import AppDatasetJoin
from models.workflow import Workflow


@app_published_workflow_was_updated.connect
def handle(sender, **kwargs):
    app = sender
    published_workflow = kwargs.get("published_workflow")
    published_workflow = cast(Workflow, published_workflow)

    dataset_ids = get_dataset_ids_from_workflow(published_workflow)
    app_dataset_joins = db.session.scalars(select(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id)).all()

    removed_dataset_ids: set[str] = set()
    if not app_dataset_joins:
        added_dataset_ids = dataset_ids
    else:
        old_dataset_ids: set[str] = set()
        old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins)

        added_dataset_ids = dataset_ids - old_dataset_ids
        removed_dataset_ids = old_dataset_ids - dataset_ids

    if removed_dataset_ids:
        for dataset_id in removed_dataset_ids:
            db.session.query(AppDatasetJoin).where(
                AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
            ).delete()

    if added_dataset_ids:
        for dataset_id in added_dataset_ids:
            app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id)
            db.session.add(app_dataset_join)

    db.session.commit()


def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[str]:
    dataset_ids: set[str] = set()
    graph = published_workflow.graph_dict
    if not graph:
        return dataset_ids

    nodes = graph.get("nodes", [])

    # fetch all knowledge retrieval nodes
    knowledge_retrieval_nodes = [
        node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL
    ]

    if not knowledge_retrieval_nodes:
        return dataset_ids

    for node in knowledge_retrieval_nodes:
        try:
            node_data = KnowledgeRetrievalNodeData.model_validate(node.get("data", {}))
            dataset_ids.update(dataset_id for dataset_id in node_data.dataset_ids)
        except Exception:
            continue

    return dataset_ids
114  dify/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py  Normal file
@@ -0,0 +1,114 @@
from typing import cast

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.workflow.nodes import NodeType
from events.app_event import app_published_workflow_was_updated
from extensions.ext_database import db
from models import AppMode
from models.enums import AppTriggerStatus
from models.trigger import AppTrigger
from models.workflow import Workflow


@app_published_workflow_was_updated.connect
def handle(sender, **kwargs):
    """
    Handle the app-published-workflow-updated event to sync the app_triggers table.

    When a workflow is published, this handler will:
    1. Extract trigger nodes from the workflow graph
    2. Compare them with existing app_triggers records
    3. Add new triggers and remove obsolete ones
    """
    app = sender
    if app.mode != AppMode.WORKFLOW.value:
        return

    published_workflow = kwargs.get("published_workflow")
    published_workflow = cast(Workflow, published_workflow)
    # Extract trigger info from the workflow
    trigger_infos = get_trigger_infos_from_workflow(published_workflow)

    with Session(db.engine) as session:
        # Get existing app triggers
        existing_triggers = (
            session.execute(
                select(AppTrigger).where(AppTrigger.tenant_id == app.tenant_id, AppTrigger.app_id == app.id)
            )
            .scalars()
            .all()
        )

        # Convert existing triggers to a dict for easy lookup
        existing_triggers_map = {trigger.node_id: trigger for trigger in existing_triggers}

        # Get current and new node IDs
        existing_node_ids = set(existing_triggers_map.keys())
        new_node_ids = {info["node_id"] for info in trigger_infos}

        # Calculate changes
        added_node_ids = new_node_ids - existing_node_ids
        removed_node_ids = existing_node_ids - new_node_ids

        # Remove obsolete triggers
        for node_id in removed_node_ids:
            session.delete(existing_triggers_map[node_id])

        for trigger_info in trigger_infos:
            node_id = trigger_info["node_id"]

            if node_id in added_node_ids:
                # Create a new trigger
                app_trigger = AppTrigger(
                    tenant_id=app.tenant_id,
                    app_id=app.id,
                    trigger_type=trigger_info["node_type"],
                    title=trigger_info["node_title"],
                    node_id=node_id,
                    provider_name=trigger_info.get("node_provider_name", ""),
                    status=AppTriggerStatus.ENABLED,
                )
                session.add(app_trigger)
            elif node_id in existing_node_ids:
                # Update the existing trigger if needed
                existing_trigger = existing_triggers_map[node_id]
                new_title = trigger_info["node_title"]
                if new_title and existing_trigger.title != new_title:
                    existing_trigger.title = new_title
                    session.add(existing_trigger)

        session.commit()


def get_trigger_infos_from_workflow(published_workflow: Workflow) -> list[dict]:
    """
    Extract trigger node information from the workflow graph.

    Returns:
        A list of trigger info dictionaries containing:
        - node_type: the type of the trigger node ('trigger-webhook', 'trigger-schedule', 'trigger-plugin')
        - node_id: the node ID in the workflow
        - node_title: the title of the node
        - node_provider_name: the name of the node's provider (plugin triggers only)
    """
    graph = published_workflow.graph_dict
    if not graph:
        return []

    nodes = graph.get("nodes", [])
    trigger_types = {NodeType.TRIGGER_WEBHOOK.value, NodeType.TRIGGER_SCHEDULE.value, NodeType.TRIGGER_PLUGIN.value}

    trigger_infos = [
        {
            "node_type": node.get("data", {}).get("type"),
            "node_id": node.get("id"),
            "node_title": node.get("data", {}).get("title"),
            "node_provider_name": node.get("data", {}).get("provider_name"),
        }
        for node in nodes
        if node.get("data", {}).get("type") in trigger_types
    ]

    return trigger_infos
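For reference, a workflow node qualifies as a trigger when its data.type matches one of the three trigger node types listed in the docstring above. A small, hypothetical graph fragment and the trigger info the extractor would produce from it:

    # Hypothetical graph_dict fragment with one webhook trigger node
    graph = {
        "nodes": [
            {
                "id": "node-1",
                "data": {"type": "trigger-webhook", "title": "Incoming webhook"},
            },
            {"id": "node-2", "data": {"type": "llm", "title": "LLM"}},
        ]
    }
    # get_trigger_infos_from_workflow would return:
    # [{"node_type": "trigger-webhook", "node_id": "node-1",
    #   "node_title": "Incoming webhook", "node_provider_name": None}]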
291  dify/api/events/event_handlers/update_provider_when_message_created.py  Normal file
@@ -0,0 +1,291 @@
import logging
import time as time_module
from datetime import datetime
from typing import Any, cast

from pydantic import BaseModel
from sqlalchemy import update
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session

from configs import dify_config
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit, SystemConfiguration
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client, redis_fallback
from libs import datetime_utils
from models.model import Message
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID

logger = logging.getLogger(__name__)

# Redis cache key prefix for provider last used timestamps
_PROVIDER_LAST_USED_CACHE_PREFIX = "provider:last_used"
# Default TTL for cache entries (10 minutes)
_CACHE_TTL_SECONDS = 600
LAST_USED_UPDATE_WINDOW_SECONDS = 60 * 5


def _get_provider_cache_key(tenant_id: str, provider_name: str) -> str:
    """Generate Redis cache key for provider last used timestamp."""
    return f"{_PROVIDER_LAST_USED_CACHE_PREFIX}:{tenant_id}:{provider_name}"


@redis_fallback(default_return=None)
def _get_last_update_timestamp(cache_key: str) -> datetime | None:
    """Get last update timestamp from Redis cache."""
    timestamp_str = redis_client.get(cache_key)
    if timestamp_str:
        return datetime.fromtimestamp(float(timestamp_str.decode("utf-8")))
    return None


@redis_fallback()
def _set_last_update_timestamp(cache_key: str, timestamp: datetime):
    """Set last update timestamp in Redis cache with TTL."""
    redis_client.setex(cache_key, _CACHE_TTL_SECONDS, str(timestamp.timestamp()))


class _ProviderUpdateFilters(BaseModel):
    """Filters for identifying Provider records to update."""

    tenant_id: str
    provider_name: str
    provider_type: str | None = None
    quota_type: str | None = None


class _ProviderUpdateAdditionalFilters(BaseModel):
    """Additional filters for Provider updates."""

    quota_limit_check: bool = False


class _ProviderUpdateValues(BaseModel):
    """Values to update in Provider records."""

    last_used: datetime | None = None
    quota_used: Any | None = None  # Can be Provider.quota_used + int expression


class _ProviderUpdateOperation(BaseModel):
    """A single Provider update operation."""

    filters: _ProviderUpdateFilters
    values: _ProviderUpdateValues
    additional_filters: _ProviderUpdateAdditionalFilters = _ProviderUpdateAdditionalFilters()
    description: str = "unknown"


@message_was_created.connect
def handle(sender: Message, **kwargs):
    """
    Consolidated handler for Provider updates when a message is created.

    This handler replaces both:
    - update_provider_last_used_at_when_message_created
    - deduct_quota_when_message_created

    By performing all Provider updates in a single transaction, we ensure
    consistency and efficiency when updating Provider records.
    """
    message = sender
    application_generate_entity = kwargs.get("application_generate_entity")

    if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
        return

    tenant_id = application_generate_entity.app_config.tenant_id
    provider_name = application_generate_entity.model_conf.provider
    current_time = datetime_utils.naive_utc_now()

    # Prepare updates for both scenarios
    updates_to_perform: list[_ProviderUpdateOperation] = []

    # 1. Always update last_used for the provider
    basic_update = _ProviderUpdateOperation(
        filters=_ProviderUpdateFilters(
            tenant_id=tenant_id,
            provider_name=provider_name,
        ),
        values=_ProviderUpdateValues(last_used=current_time),
        description="basic_last_used_update",
    )
    logger.info("provider used, tenant_id=%s, provider_name=%s", tenant_id, provider_name)
    updates_to_perform.append(basic_update)

    # 2. Check if we need to deduct quota (system provider only)
    model_config = application_generate_entity.model_conf
    provider_model_bundle = model_config.provider_model_bundle
    provider_configuration = provider_model_bundle.configuration

    if (
        provider_configuration.using_provider_type == ProviderType.SYSTEM
        and provider_configuration.system_configuration
        and provider_configuration.system_configuration.current_quota_type is not None
    ):
        system_configuration = provider_configuration.system_configuration

        # Calculate quota usage
        used_quota = _calculate_quota_usage(
            message=message,
            system_configuration=system_configuration,
            model_name=model_config.model,
        )

        if used_quota is not None:
            quota_update = _ProviderUpdateOperation(
                filters=_ProviderUpdateFilters(
                    tenant_id=tenant_id,
                    provider_name=ModelProviderID(model_config.provider).provider_name,
                    provider_type=ProviderType.SYSTEM,
                    quota_type=provider_configuration.system_configuration.current_quota_type.value,
                ),
                values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
                additional_filters=_ProviderUpdateAdditionalFilters(
                    quota_limit_check=True  # Provider.quota_limit > Provider.quota_used
                ),
                description="quota_deduction_update",
            )
            updates_to_perform.append(quota_update)

    # Execute all updates
    start_time = time_module.perf_counter()
    try:
        _execute_provider_updates(updates_to_perform)

        # Log successful completion with timing
        duration = time_module.perf_counter() - start_time

        logger.info(
            "Provider updates completed successfully. Updates: %s, Duration: %s s, Tenant: %s, Provider: %s",
            len(updates_to_perform),
            duration,
            tenant_id,
            provider_name,
        )

    except Exception:
        # Log failure with timing and context
        duration = time_module.perf_counter() - start_time

        logger.exception(
            "Provider updates failed after %s s. Updates: %s, Tenant: %s, Provider: %s",
            duration,
            len(updates_to_perform),
            tenant_id,
            provider_name,
        )
        raise


def _calculate_quota_usage(
    *, message: Message, system_configuration: SystemConfiguration, model_name: str
) -> int | None:
    """Calculate quota usage based on message tokens and quota type."""
    quota_unit = None
    for quota_configuration in system_configuration.quota_configurations:
        if quota_configuration.quota_type == system_configuration.current_quota_type:
            quota_unit = quota_configuration.quota_unit
            if quota_configuration.quota_limit == -1:
                return None
            break
    if quota_unit is None:
        return None

    try:
        if quota_unit == QuotaUnit.TOKENS:
            tokens = message.message_tokens + message.answer_tokens
            return tokens
        if quota_unit == QuotaUnit.CREDITS:
            tokens = dify_config.get_model_credits(model_name)
            return tokens
        elif quota_unit == QuotaUnit.TIMES:
            return 1
        return None
    except Exception:
        logger.exception("Failed to calculate quota usage")
        return None


def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]):
    """Execute all Provider updates in a single transaction."""
    if not updates_to_perform:
        return

    updates_to_perform = sorted(updates_to_perform, key=lambda i: (i.filters.tenant_id, i.filters.provider_name))

    # Use SQLAlchemy's context manager for transaction management.
    # This automatically handles commit/rollback.
    with Session(db.engine) as session, session.begin():
        # Use a single transaction for all updates
        for update_operation in updates_to_perform:
            filters = update_operation.filters
            values = update_operation.values
            additional_filters = update_operation.additional_filters
            description = update_operation.description

            # Build the where conditions
            where_conditions = [
                Provider.tenant_id == filters.tenant_id,
                Provider.provider_name == filters.provider_name,
            ]

            # Add additional filters if specified
            if filters.provider_type is not None:
                where_conditions.append(Provider.provider_type == filters.provider_type)
            if filters.quota_type is not None:
                where_conditions.append(Provider.quota_type == filters.quota_type)
            if additional_filters.quota_limit_check:
                where_conditions.append(Provider.quota_limit > Provider.quota_used)

            # Prepare values dict for SQLAlchemy update
            update_values = {}

            # NOTE: For frequently used providers under high load, this implementation may experience
            # race conditions or update contention despite the time-window optimization:
            # 1. Multiple concurrent requests might check the same cache key simultaneously
            # 2. Redis cache operations are not atomic with database updates
            # 3. Heavy providers could still face database lock contention during peak usage
            # The current implementation is acceptable for most scenarios, but future optimization
            # considerations could include batched updates or async processing.
            if values.last_used is not None:
                cache_key = _get_provider_cache_key(filters.tenant_id, filters.provider_name)
                now = datetime_utils.naive_utc_now()
                last_update = _get_last_update_timestamp(cache_key)

                if last_update is None or (now - last_update).total_seconds() > LAST_USED_UPDATE_WINDOW_SECONDS:
                    update_values["last_used"] = values.last_used
                    _set_last_update_timestamp(cache_key, now)

            if values.quota_used is not None:
                update_values["quota_used"] = values.quota_used
            # Skip the current update operation if no updates are required.
            if not update_values:
                continue

            # Build and execute the update statement
            stmt = update(Provider).where(*where_conditions).values(**update_values)
            result = cast(CursorResult, session.execute(stmt))
            rows_affected = result.rowcount

            logger.debug(
                "Provider update (%s): %s rows affected. Filters: %s, Values: %s",
                description,
                rows_affected,
                filters.model_dump(),
                update_values,
            )

            # If no rows were affected for quota updates, log a warning
            if rows_affected == 0 and description == "quota_deduction_update":
                logger.warning(
                    "No Provider rows updated for quota deduction. "
                    "This may indicate quota limit exceeded or provider not found. "
                    "Filters: %s",
                    filters.model_dump(),
                )

        logger.debug("Successfully processed %s Provider updates", len(updates_to_perform))
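The handler bails out unless the application_generate_entity keyword is a chat or agent-chat generate entity, so other app types never touch provider quota here. A hedged sketch of the emitting side (the actual call sites live elsewhere in dify and may pass additional keywords):

    # Illustrative emit, assuming blinker's send(sender, **kwargs) contract
    message_was_created.send(
        message,
        application_generate_entity=application_generate_entity,
    )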