2025-12-01 17:21:38 +08:00
parent 32fee2b8ab
commit fab8c13cb3
7511 changed files with 996300 additions and 0 deletions

View File

@@ -0,0 +1,13 @@
from blinker import signal
# sender: app
app_was_created = signal("app-was-created")
# sender: app, kwargs: app_model_config
app_model_config_was_updated = signal("app-model-config-was-updated")
# sender: app, kwargs: published_workflow
app_published_workflow_was_updated = signal("app-published-workflow-was-updated")
# sender: app, kwargs: synced_draft_workflow
app_draft_workflow_was_synced = signal("app-draft-workflow-was-synced")
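These are standard blinker named signals: a receiver subscribes with the signal's connect decorator, and the emitter calls send() with the sender plus the keyword arguments noted in the comments above. A minimal, self-contained sketch of that flow, assuming only the blinker package and an illustrative stand-in in place of the real App model:

from blinker import signal

app_was_created = signal("app-was-created")  # same named signal as above

@app_was_created.connect
def on_app_created(sender, **kwargs):
    # sender is the app object; extra context arrives as keyword arguments
    print(f"app created: {sender.id}, account={kwargs.get('account')}")

class DemoApp:  # illustrative stand-in, not the real App model
    id = "app-1"

# connected receivers run synchronously inside send()
app_was_created.send(DemoApp(), account="demo-account")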

View File

@@ -0,0 +1,4 @@
from blinker import signal
# sender: dataset
dataset_was_deleted = signal("dataset-was-deleted")

View File

@@ -0,0 +1,4 @@
from blinker import signal
# sender: document
document_was_deleted = signal("document-was-deleted")

View File

@@ -0,0 +1,4 @@
from blinker import signal
# sender: document
document_index_created = signal("document-index-created")

View File

@@ -0,0 +1,40 @@
from .clean_when_dataset_deleted import handle as handle_clean_when_dataset_deleted
from .clean_when_document_deleted import handle as handle_clean_when_document_deleted
from .create_document_index import handle as handle_create_document_index
from .create_installed_app_when_app_created import handle as handle_create_installed_app_when_app_created
from .create_site_record_when_app_created import handle as handle_create_site_record_when_app_created
from .delete_tool_parameters_cache_when_sync_draft_workflow import (
handle as handle_delete_tool_parameters_cache_when_sync_draft_workflow,
)
from .sync_plugin_trigger_when_app_created import handle as handle_sync_plugin_trigger_when_app_created
from .sync_webhook_when_app_created import handle as handle_sync_webhook_when_app_created
from .sync_workflow_schedule_when_app_published import handle as handle_sync_workflow_schedule_when_app_published
from .update_app_dataset_join_when_app_model_config_updated import (
handle as handle_update_app_dataset_join_when_app_model_config_updated,
)
from .update_app_dataset_join_when_app_published_workflow_updated import (
handle as handle_update_app_dataset_join_when_app_published_workflow_updated,
)
from .update_app_triggers_when_app_published_workflow_updated import (
handle as handle_update_app_triggers_when_app_published_workflow_updated,
)
# Consolidated handler replaces both deduct_quota_when_message_created and
# update_provider_last_used_at_when_message_created
from .update_provider_when_message_created import handle as handle_update_provider_when_message_created
__all__ = [
"handle_clean_when_dataset_deleted",
"handle_clean_when_document_deleted",
"handle_create_document_index",
"handle_create_installed_app_when_app_created",
"handle_create_site_record_when_app_created",
"handle_delete_tool_parameters_cache_when_sync_draft_workflow",
"handle_sync_plugin_trigger_when_app_created",
"handle_sync_webhook_when_app_created",
"handle_sync_workflow_schedule_when_app_published",
"handle_update_app_dataset_join_when_app_model_config_updated",
"handle_update_app_dataset_join_when_app_published_workflow_updated",
"handle_update_app_triggers_when_app_published_workflow_updated",
"handle_update_provider_when_message_created",
]
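Registration is an import side effect: each handler module runs its @signal.connect decorator at import time, so importing the events.event_handlers package (whose __init__ pulls in every module above) is enough to wire all receivers. A hedged sketch of how an application factory might do this (the create_app name is illustrative, not part of this diff):

def create_app():
    # Importing the package executes every handler module; their
    # @signal.connect decorators register the receivers as a side effect.
    import events.event_handlers  # noqa: F401

    # From this point on, sending any of the events fires its handlers.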

View File

@@ -0,0 +1,18 @@
from events.dataset_event import dataset_was_deleted
from models import Dataset
from tasks.clean_dataset_task import clean_dataset_task
@dataset_was_deleted.connect
def handle(sender: Dataset, **kwargs):
dataset = sender
if not dataset.doc_form or not dataset.indexing_technique:
return
clean_dataset_task.delay(
dataset.id,
dataset.tenant_id,
dataset.indexing_technique,
dataset.index_struct,
dataset.collection_binding_id,
dataset.doc_form,
)

View File

@@ -0,0 +1,13 @@
from events.document_event import document_was_deleted
from tasks.clean_document_task import clean_document_task
@document_was_deleted.connect
def handle(sender, **kwargs):
document_id = sender
dataset_id = kwargs.get("dataset_id")
doc_form = kwargs.get("doc_form")
file_id = kwargs.get("file_id")
if not dataset_id or not doc_form:
return
clean_document_task.delay(document_id, dataset_id, doc_form, file_id)

View File

@@ -0,0 +1,51 @@
import contextlib
import logging
import time
import click
from werkzeug.exceptions import NotFound
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from events.document_index_event import document_index_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Document
logger = logging.getLogger(__name__)
@document_index_created.connect
def handle(sender, **kwargs):
dataset_id = sender
document_ids = kwargs.get("document_ids", [])
documents = []
start_at = time.perf_counter()
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = (
db.session.query(Document)
.where(
Document.id == document_id,
Document.dataset_id == dataset_id,
)
.first()
)
if not document:
raise NotFound("Document not found")
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)
db.session.commit()
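    # Suppress any unexpected indexing error below; a paused document is only logged.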
with contextlib.suppress(Exception):
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))

View File

@@ -0,0 +1,16 @@
from events.app_event import app_was_created
from extensions.ext_database import db
from models.model import InstalledApp
@app_was_created.connect
def handle(sender, **kwargs):
"""Create an installed app when an app is created."""
app = sender
installed_app = InstalledApp(
tenant_id=app.tenant_id,
app_id=app.id,
app_owner_tenant_id=app.tenant_id,
)
db.session.add(installed_app)
db.session.commit()

View File

@@ -0,0 +1,26 @@
from events.app_event import app_was_created
from extensions.ext_database import db
from models.model import Site
@app_was_created.connect
def handle(sender, **kwargs):
"""Create site record when an app is created."""
app = sender
account = kwargs.get("account")
if account is not None:
site = Site(
app_id=app.id,
title=app.name,
icon_type=app.icon_type,
icon=app.icon,
icon_background=app.icon_background,
default_language=account.interface_language,
customize_token_strategy="not_allow",
code=Site.generate_code(16),
created_by=app.created_by,
updated_by=app.updated_by,
)
db.session.add(site)
db.session.commit()

View File

@@ -0,0 +1,35 @@
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolEntity
from events.app_event import app_draft_workflow_was_synced
@app_draft_workflow_was_synced.connect
def handle(sender, **kwargs):
app = sender
synced_draft_workflow = kwargs.get("synced_draft_workflow")
if synced_draft_workflow is None:
return
for node_data in synced_draft_workflow.graph_dict.get("nodes", []):
if node_data.get("data", {}).get("type") == NodeType.TOOL:
try:
tool_entity = ToolEntity.model_validate(node_data["data"])
tool_runtime = ToolManager.get_tool_runtime(
provider_type=tool_entity.provider_type,
provider_id=tool_entity.provider_id,
tool_name=tool_entity.tool_name,
tenant_id=app.tenant_id,
credential_id=tool_entity.credential_id,
)
manager = ToolParameterConfigurationManager(
tenant_id=app.tenant_id,
tool_runtime=tool_runtime,
provider_name=tool_entity.provider_name,
provider_type=tool_entity.provider_type,
identity_id=f"WORKFLOW.{app.id}.{node_data.get('id')}",
)
manager.delete_tool_parameters_cache()
            except Exception:
                # tool does not exist
                pass

View File

@@ -0,0 +1,22 @@
import logging
from events.app_event import app_draft_workflow_was_synced
from models.model import App, AppMode
from models.workflow import Workflow
from services.trigger.trigger_service import TriggerService
logger = logging.getLogger(__name__)
@app_draft_workflow_was_synced.connect
def handle(sender, synced_draft_workflow: Workflow, **kwargs):
"""
    When a draft workflow is created or updated, sync its plugin trigger
    relationships in the database.
"""
app: App = sender
if app.mode != AppMode.WORKFLOW.value:
        # only handle workflow apps; chatflow is not supported yet
return
TriggerService.sync_plugin_trigger_relationships(app, synced_draft_workflow)

View File

@@ -0,0 +1,22 @@
import logging
from events.app_event import app_draft_workflow_was_synced
from models.model import App, AppMode
from models.workflow import Workflow
from services.trigger.webhook_service import WebhookService
logger = logging.getLogger(__name__)
@app_draft_workflow_was_synced.connect
def handle(sender, synced_draft_workflow: Workflow, **kwargs):
"""
    When a draft workflow is created or updated, sync its webhook
    relationships in the database.
"""
app: App = sender
if app.mode != AppMode.WORKFLOW.value:
        # only handle workflow apps; chatflow is not supported yet
return
WebhookService.sync_webhook_relationships(app, synced_draft_workflow)

View File

@@ -0,0 +1,86 @@
import logging
from typing import cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.workflow.nodes.trigger_schedule.entities import SchedulePlanUpdate
from events.app_event import app_published_workflow_was_updated
from extensions.ext_database import db
from models import AppMode, Workflow, WorkflowSchedulePlan
from services.trigger.schedule_service import ScheduleService
logger = logging.getLogger(__name__)
@app_published_workflow_was_updated.connect
def handle(sender, **kwargs):
"""
Handle app published workflow update event to sync workflow_schedule_plans table.
When a workflow is published, this handler will:
1. Extract schedule trigger nodes from the workflow graph
2. Compare with existing workflow_schedule_plans records
3. Create/update/delete schedule plans as needed
"""
app = sender
if app.mode != AppMode.WORKFLOW.value:
return
published_workflow = kwargs.get("published_workflow")
published_workflow = cast(Workflow, published_workflow)
sync_schedule_from_workflow(tenant_id=app.tenant_id, app_id=app.id, workflow=published_workflow)
def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow) -> WorkflowSchedulePlan | None:
"""
Sync schedule plan from workflow graph configuration.
Args:
tenant_id: Tenant ID
app_id: App ID
workflow: Published workflow instance
Returns:
Updated or created WorkflowSchedulePlan, or None if no schedule node
"""
with Session(db.engine) as session:
schedule_config = ScheduleService.extract_schedule_config(workflow)
existing_plan = session.scalar(
select(WorkflowSchedulePlan).where(
WorkflowSchedulePlan.tenant_id == tenant_id,
WorkflowSchedulePlan.app_id == app_id,
)
)
if not schedule_config:
if existing_plan:
logger.info("No schedule node in workflow for app %s, removing schedule plan", app_id)
ScheduleService.delete_schedule(session=session, schedule_id=existing_plan.id)
session.commit()
return None
if existing_plan:
updates = SchedulePlanUpdate(
node_id=schedule_config.node_id,
cron_expression=schedule_config.cron_expression,
timezone=schedule_config.timezone,
)
updated_plan = ScheduleService.update_schedule(
session=session,
schedule_id=existing_plan.id,
updates=updates,
)
session.commit()
return updated_plan
else:
new_plan = ScheduleService.create_schedule(
session=session,
tenant_id=tenant_id,
app_id=app_id,
config=schedule_config,
)
session.commit()
return new_plan

View File

@@ -0,0 +1,70 @@
from sqlalchemy import select
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from models.dataset import AppDatasetJoin
from models.model import AppModelConfig
@app_model_config_was_updated.connect
def handle(sender, **kwargs):
app = sender
app_model_config = kwargs.get("app_model_config")
if app_model_config is None:
return
dataset_ids = get_dataset_ids_from_model_config(app_model_config)
app_dataset_joins = db.session.scalars(select(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id)).all()
removed_dataset_ids: set[str] = set()
if not app_dataset_joins:
added_dataset_ids = dataset_ids
else:
old_dataset_ids: set[str] = set()
old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins)
added_dataset_ids = dataset_ids - old_dataset_ids
removed_dataset_ids = old_dataset_ids - dataset_ids
if removed_dataset_ids:
for dataset_id in removed_dataset_ids:
db.session.query(AppDatasetJoin).where(
AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
).delete()
if added_dataset_ids:
for dataset_id in added_dataset_ids:
app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id)
db.session.add(app_dataset_join)
db.session.commit()
def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set[str]:
dataset_ids: set[str] = set()
if not app_model_config:
return dataset_ids
agent_mode = app_model_config.agent_mode_dict
tools = agent_mode.get("tools", []) or []
for tool in tools:
if len(list(tool.keys())) != 1:
continue
tool_type = list(tool.keys())[0]
tool_config = list(tool.values())[0]
if tool_type == "dataset":
dataset_ids.add(tool_config.get("id"))
# get dataset from dataset_configs
dataset_configs = app_model_config.dataset_configs_dict
datasets = dataset_configs.get("datasets", {}) or {}
for dataset in datasets.get("datasets", []) or []:
keys = list(dataset.keys())
if len(keys) == 1 and keys[0] == "dataset":
if dataset["dataset"].get("id"):
dataset_ids.add(dataset["dataset"].get("id"))
return dataset_ids
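For reference, the two dictionary shapes this function walks look roughly like the following (the IDs and the web-search entry are invented for illustration): an agent_mode tool entry is a single-key dict whose key names the tool type, and dataset_configs nests each dataset under a "dataset" key.

# agent_mode_dict: one single-key entry per tool
agent_mode = {
    "tools": [
        {"dataset": {"id": "d1f0c6f2-0000-0000-0000-000000000000"}},  # collected
        {"web-search": {"enabled": True}},  # skipped: not a dataset tool
    ]
}

# dataset_configs_dict: dataset list nested under "datasets" twice
dataset_configs = {
    "datasets": {
        "datasets": [
            {"dataset": {"id": "9a7b3c10-0000-0000-0000-000000000000"}},  # collected
        ]
    }
}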

View File

@@ -0,0 +1,69 @@
from typing import cast
from sqlalchemy import select
from core.workflow.nodes import NodeType
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from events.app_event import app_published_workflow_was_updated
from extensions.ext_database import db
from models.dataset import AppDatasetJoin
from models.workflow import Workflow
@app_published_workflow_was_updated.connect
def handle(sender, **kwargs):
app = sender
published_workflow = kwargs.get("published_workflow")
published_workflow = cast(Workflow, published_workflow)
dataset_ids = get_dataset_ids_from_workflow(published_workflow)
app_dataset_joins = db.session.scalars(select(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id)).all()
removed_dataset_ids: set[str] = set()
if not app_dataset_joins:
added_dataset_ids = dataset_ids
else:
old_dataset_ids: set[str] = set()
old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins)
added_dataset_ids = dataset_ids - old_dataset_ids
removed_dataset_ids = old_dataset_ids - dataset_ids
if removed_dataset_ids:
for dataset_id in removed_dataset_ids:
db.session.query(AppDatasetJoin).where(
AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
).delete()
if added_dataset_ids:
for dataset_id in added_dataset_ids:
app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id)
db.session.add(app_dataset_join)
db.session.commit()
def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[str]:
dataset_ids: set[str] = set()
graph = published_workflow.graph_dict
if not graph:
return dataset_ids
nodes = graph.get("nodes", [])
# fetch all knowledge retrieval nodes
knowledge_retrieval_nodes = [
node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL
]
if not knowledge_retrieval_nodes:
return dataset_ids
for node in knowledge_retrieval_nodes:
try:
node_data = KnowledgeRetrievalNodeData.model_validate(node.get("data", {}))
            dataset_ids.update(node_data.dataset_ids)
except Exception:
continue
return dataset_ids

View File

@@ -0,0 +1,114 @@
from typing import cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.workflow.nodes import NodeType
from events.app_event import app_published_workflow_was_updated
from extensions.ext_database import db
from models import AppMode
from models.enums import AppTriggerStatus
from models.trigger import AppTrigger
from models.workflow import Workflow
@app_published_workflow_was_updated.connect
def handle(sender, **kwargs):
"""
Handle app published workflow update event to sync app_triggers table.
When a workflow is published, this handler will:
1. Extract trigger nodes from the workflow graph
2. Compare with existing app_triggers records
3. Add new triggers and remove obsolete ones
"""
app = sender
if app.mode != AppMode.WORKFLOW.value:
return
published_workflow = kwargs.get("published_workflow")
published_workflow = cast(Workflow, published_workflow)
# Extract trigger info from workflow
trigger_infos = get_trigger_infos_from_workflow(published_workflow)
with Session(db.engine) as session:
# Get existing app triggers
existing_triggers = (
session.execute(
select(AppTrigger).where(AppTrigger.tenant_id == app.tenant_id, AppTrigger.app_id == app.id)
)
.scalars()
.all()
)
# Convert existing triggers to dict for easy lookup
existing_triggers_map = {trigger.node_id: trigger for trigger in existing_triggers}
# Get current and new node IDs
existing_node_ids = set(existing_triggers_map.keys())
new_node_ids = {info["node_id"] for info in trigger_infos}
# Calculate changes
added_node_ids = new_node_ids - existing_node_ids
removed_node_ids = existing_node_ids - new_node_ids
# Remove obsolete triggers
for node_id in removed_node_ids:
session.delete(existing_triggers_map[node_id])
for trigger_info in trigger_infos:
node_id = trigger_info["node_id"]
if node_id in added_node_ids:
# Create new trigger
app_trigger = AppTrigger(
tenant_id=app.tenant_id,
app_id=app.id,
trigger_type=trigger_info["node_type"],
title=trigger_info["node_title"],
node_id=node_id,
provider_name=trigger_info.get("node_provider_name", ""),
status=AppTriggerStatus.ENABLED,
)
session.add(app_trigger)
elif node_id in existing_node_ids:
# Update existing trigger if needed
existing_trigger = existing_triggers_map[node_id]
new_title = trigger_info["node_title"]
if new_title and existing_trigger.title != new_title:
existing_trigger.title = new_title
session.add(existing_trigger)
session.commit()
def get_trigger_infos_from_workflow(published_workflow: Workflow) -> list[dict]:
"""
Extract trigger node information from the workflow graph.
Returns:
List of trigger info dictionaries containing:
- node_type: The type of the trigger node ('trigger-webhook', 'trigger-schedule', 'trigger-plugin')
- node_id: The node ID in the workflow
- node_title: The title of the node
- node_provider_name: The name of the node's provider, only for plugin
"""
graph = published_workflow.graph_dict
if not graph:
return []
nodes = graph.get("nodes", [])
trigger_types = {NodeType.TRIGGER_WEBHOOK.value, NodeType.TRIGGER_SCHEDULE.value, NodeType.TRIGGER_PLUGIN.value}
trigger_infos = [
{
"node_type": node.get("data", {}).get("type"),
"node_id": node.get("id"),
"node_title": node.get("data", {}).get("title"),
"node_provider_name": node.get("data", {}).get("provider_name"),
}
for node in nodes
if node.get("data", {}).get("type") in trigger_types
]
return trigger_infos
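To make the return shape concrete, here is a hypothetical webhook trigger node from graph_dict["nodes"] and the dictionary this function would produce for it (the id and title are invented):

node = {
    "id": "1718000000000",
    "data": {"type": "trigger-webhook", "title": "Order webhook"},
}
# -> {
#     "node_type": "trigger-webhook",
#     "node_id": "1718000000000",
#     "node_title": "Order webhook",
#     "node_provider_name": None,  # only plugin triggers set provider_name
# }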

View File

@@ -0,0 +1,291 @@
import logging
import time as time_module
from datetime import datetime
from typing import Any, cast
from pydantic import BaseModel
from sqlalchemy import update
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit, SystemConfiguration
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client, redis_fallback
from libs import datetime_utils
from models.model import Message
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)
# Redis cache key prefix for provider last used timestamps
_PROVIDER_LAST_USED_CACHE_PREFIX = "provider:last_used"
# Default TTL for cache entries (10 minutes)
_CACHE_TTL_SECONDS = 600
LAST_USED_UPDATE_WINDOW_SECONDS = 60 * 5
def _get_provider_cache_key(tenant_id: str, provider_name: str) -> str:
"""Generate Redis cache key for provider last used timestamp."""
return f"{_PROVIDER_LAST_USED_CACHE_PREFIX}:{tenant_id}:{provider_name}"
@redis_fallback(default_return=None)
def _get_last_update_timestamp(cache_key: str) -> datetime | None:
"""Get last update timestamp from Redis cache."""
timestamp_str = redis_client.get(cache_key)
if timestamp_str:
return datetime.fromtimestamp(float(timestamp_str.decode("utf-8")))
return None
@redis_fallback()
def _set_last_update_timestamp(cache_key: str, timestamp: datetime):
"""Set last update timestamp in Redis cache with TTL."""
redis_client.setex(cache_key, _CACHE_TTL_SECONDS, str(timestamp.timestamp()))
class _ProviderUpdateFilters(BaseModel):
"""Filters for identifying Provider records to update."""
tenant_id: str
provider_name: str
provider_type: str | None = None
quota_type: str | None = None
class _ProviderUpdateAdditionalFilters(BaseModel):
"""Additional filters for Provider updates."""
quota_limit_check: bool = False
class _ProviderUpdateValues(BaseModel):
"""Values to update in Provider records."""
last_used: datetime | None = None
quota_used: Any | None = None # Can be Provider.quota_used + int expression
class _ProviderUpdateOperation(BaseModel):
"""A single Provider update operation."""
filters: _ProviderUpdateFilters
values: _ProviderUpdateValues
additional_filters: _ProviderUpdateAdditionalFilters = _ProviderUpdateAdditionalFilters()
description: str = "unknown"
@message_was_created.connect
def handle(sender: Message, **kwargs):
"""
Consolidated handler for Provider updates when a message is created.
This handler replaces both:
- update_provider_last_used_at_when_message_created
- deduct_quota_when_message_created
By performing all Provider updates in a single transaction, we ensure
consistency and efficiency when updating Provider records.
"""
message = sender
application_generate_entity = kwargs.get("application_generate_entity")
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
return
tenant_id = application_generate_entity.app_config.tenant_id
provider_name = application_generate_entity.model_conf.provider
current_time = datetime_utils.naive_utc_now()
# Prepare updates for both scenarios
updates_to_perform: list[_ProviderUpdateOperation] = []
# 1. Always update last_used for the provider
basic_update = _ProviderUpdateOperation(
filters=_ProviderUpdateFilters(
tenant_id=tenant_id,
provider_name=provider_name,
),
values=_ProviderUpdateValues(last_used=current_time),
description="basic_last_used_update",
)
logger.info("provider used, tenant_id=%s, provider_name=%s", tenant_id, provider_name)
updates_to_perform.append(basic_update)
# 2. Check if we need to deduct quota (system provider only)
model_config = application_generate_entity.model_conf
provider_model_bundle = model_config.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if (
provider_configuration.using_provider_type == ProviderType.SYSTEM
and provider_configuration.system_configuration
and provider_configuration.system_configuration.current_quota_type is not None
):
system_configuration = provider_configuration.system_configuration
# Calculate quota usage
used_quota = _calculate_quota_usage(
message=message,
system_configuration=system_configuration,
model_name=model_config.model,
)
if used_quota is not None:
quota_update = _ProviderUpdateOperation(
filters=_ProviderUpdateFilters(
tenant_id=tenant_id,
provider_name=ModelProviderID(model_config.provider).provider_name,
provider_type=ProviderType.SYSTEM,
quota_type=provider_configuration.system_configuration.current_quota_type.value,
),
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
additional_filters=_ProviderUpdateAdditionalFilters(
quota_limit_check=True # Provider.quota_limit > Provider.quota_used
),
description="quota_deduction_update",
)
updates_to_perform.append(quota_update)
# Execute all updates
start_time = time_module.perf_counter()
try:
_execute_provider_updates(updates_to_perform)
# Log successful completion with timing
duration = time_module.perf_counter() - start_time
logger.info(
"Provider updates completed successfully. Updates: %s, Duration: %s s, Tenant: %s, Provider: %s",
len(updates_to_perform),
duration,
tenant_id,
provider_name,
)
except Exception:
# Log failure with timing and context
duration = time_module.perf_counter() - start_time
logger.exception(
"Provider updates failed after %s s. Updates: %s, Tenant: %s, Provider: %s",
duration,
len(updates_to_perform),
tenant_id,
provider_name,
)
raise
def _calculate_quota_usage(
*, message: Message, system_configuration: SystemConfiguration, model_name: str
) -> int | None:
"""Calculate quota usage based on message tokens and quota type."""
quota_unit = None
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.quota_type == system_configuration.current_quota_type:
quota_unit = quota_configuration.quota_unit
if quota_configuration.quota_limit == -1:
return None
break
if quota_unit is None:
return None
try:
        if quota_unit == QuotaUnit.TOKENS:
            return message.message_tokens + message.answer_tokens
        elif quota_unit == QuotaUnit.CREDITS:
            return dify_config.get_model_credits(model_name)
        elif quota_unit == QuotaUnit.TIMES:
            return 1
        return None
except Exception:
logger.exception("Failed to calculate quota usage")
return None
def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]):
"""Execute all Provider updates in a single transaction."""
if not updates_to_perform:
return
updates_to_perform = sorted(updates_to_perform, key=lambda i: (i.filters.tenant_id, i.filters.provider_name))
# Use SQLAlchemy's context manager for transaction management
# This automatically handles commit/rollback
with Session(db.engine) as session, session.begin():
# Use a single transaction for all updates
for update_operation in updates_to_perform:
filters = update_operation.filters
values = update_operation.values
additional_filters = update_operation.additional_filters
description = update_operation.description
# Build the where conditions
where_conditions = [
Provider.tenant_id == filters.tenant_id,
Provider.provider_name == filters.provider_name,
]
# Add additional filters if specified
if filters.provider_type is not None:
where_conditions.append(Provider.provider_type == filters.provider_type)
if filters.quota_type is not None:
where_conditions.append(Provider.quota_type == filters.quota_type)
if additional_filters.quota_limit_check:
where_conditions.append(Provider.quota_limit > Provider.quota_used)
# Prepare values dict for SQLAlchemy update
update_values = {}
# NOTE: For frequently used providers under high load, this implementation may experience
# race conditions or update contention despite the time-window optimization:
# 1. Multiple concurrent requests might check the same cache key simultaneously
# 2. Redis cache operations are not atomic with database updates
# 3. Heavy providers could still face database lock contention during peak usage
            # The current implementation is acceptable for most scenarios; future
            # optimizations could include batched updates or async processing.
if values.last_used is not None:
cache_key = _get_provider_cache_key(filters.tenant_id, filters.provider_name)
now = datetime_utils.naive_utc_now()
last_update = _get_last_update_timestamp(cache_key)
if last_update is None or (now - last_update).total_seconds() > LAST_USED_UPDATE_WINDOW_SECONDS:
update_values["last_used"] = values.last_used
_set_last_update_timestamp(cache_key, now)
if values.quota_used is not None:
update_values["quota_used"] = values.quota_used
# Skip the current update operation if no updates are required.
if not update_values:
continue
# Build and execute the update statement
stmt = update(Provider).where(*where_conditions).values(**update_values)
result = cast(CursorResult, session.execute(stmt))
rows_affected = result.rowcount
logger.debug(
"Provider update (%s): %s rows affected. Filters: %s, Values: %s",
description,
rows_affected,
filters.model_dump(),
update_values,
)
# If no rows were affected for quota updates, log a warning
if rows_affected == 0 and description == "quota_deduction_update":
logger.warning(
"No Provider rows updated for quota deduction. "
"This may indicate quota limit exceeded or provider not found. "
"Filters: %s",
filters.model_dump(),
)
logger.debug("Successfully processed %s Provider updates", len(updates_to_perform))

View File

@@ -0,0 +1,4 @@
from blinker import signal
# sender: message, kwargs: conversation
message_was_created = signal("message-was-created")

View File

@@ -0,0 +1,7 @@
from blinker import signal
# sender: tenant
tenant_was_created = signal("tenant-was-created")
# sender: tenant
tenant_was_updated = signal("tenant-was-updated")