dify
46
dify/api/services/trigger/app_trigger_service.py
Normal file
@@ -0,0 +1,46 @@
"""
AppTrigger management service.

Handles AppTrigger model CRUD operations and status management.
This service centralizes all AppTrigger-related business logic.
"""

import logging

from sqlalchemy import update
from sqlalchemy.orm import Session

from extensions.ext_database import db
from models.enums import AppTriggerStatus
from models.trigger import AppTrigger

logger = logging.getLogger(__name__)


class AppTriggerService:
    """Service for managing AppTrigger lifecycle and status."""

    @staticmethod
    def mark_tenant_triggers_rate_limited(tenant_id: str) -> None:
        """
        Mark all enabled triggers for a tenant as rate limited because its quota is exceeded.

        This method is called when a tenant's quota is exhausted. It updates all
        enabled triggers to RATE_LIMITED status to prevent further executions until
        quota is restored.

        Args:
            tenant_id: Tenant ID whose triggers should be marked as rate limited
        """
        try:
            with Session(db.engine) as session:
                session.execute(
                    update(AppTrigger)
                    .where(AppTrigger.tenant_id == tenant_id, AppTrigger.status == AppTriggerStatus.ENABLED)
                    .values(status=AppTriggerStatus.RATE_LIMITED)
                )
                session.commit()
                logger.info("Marked all enabled triggers as rate limited for tenant %s", tenant_id)
        except Exception:
            logger.exception("Failed to mark all enabled triggers as rate limited for tenant %s", tenant_id)
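For reference, a minimal usage sketch of the service above (the tenant ID is hypothetical):

    from services.trigger.app_trigger_service import AppTriggerService

    # Flips every ENABLED trigger for the tenant to RATE_LIMITED in a single UPDATE.
    # Failures are logged, not raised, so callers do not need their own try/except.
    AppTriggerService.mark_tenant_triggers_rate_limited("tenant-uuid")
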
312
dify/api/services/trigger/schedule_service.py
Normal file
@@ -0,0 +1,312 @@
import json
import logging
from collections.abc import Mapping
from datetime import datetime
from typing import Any

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.workflow.nodes import NodeType
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig
from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError
from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h
from models.account import Account, TenantAccountJoin
from models.trigger import WorkflowSchedulePlan
from models.workflow import Workflow
from services.errors.account import AccountNotFoundError

logger = logging.getLogger(__name__)


class ScheduleService:
    @staticmethod
    def create_schedule(
        session: Session,
        tenant_id: str,
        app_id: str,
        config: ScheduleConfig,
    ) -> WorkflowSchedulePlan:
        """
        Create a new schedule with validated configuration.

        Args:
            session: Database session
            tenant_id: Tenant ID
            app_id: Application ID
            config: Validated schedule configuration

        Returns:
            Created WorkflowSchedulePlan instance
        """
        next_run_at = calculate_next_run_at(
            config.cron_expression,
            config.timezone,
        )

        schedule = WorkflowSchedulePlan(
            tenant_id=tenant_id,
            app_id=app_id,
            node_id=config.node_id,
            cron_expression=config.cron_expression,
            timezone=config.timezone,
            next_run_at=next_run_at,
        )

        session.add(schedule)
        session.flush()

        return schedule

    @staticmethod
    def update_schedule(
        session: Session,
        schedule_id: str,
        updates: SchedulePlanUpdate,
    ) -> WorkflowSchedulePlan:
        """
        Update an existing schedule with validated configuration.

        Args:
            session: Database session
            schedule_id: Schedule ID to update
            updates: Validated update configuration

        Raises:
            ScheduleNotFoundError: If schedule not found

        Returns:
            Updated WorkflowSchedulePlan instance
        """
        schedule = session.get(WorkflowSchedulePlan, schedule_id)
        if not schedule:
            raise ScheduleNotFoundError(f"Schedule not found: {schedule_id}")

        # If time-related fields are updated, synchronously update next_run_at.
        time_fields_updated = False

        if updates.node_id is not None:
            schedule.node_id = updates.node_id

        if updates.cron_expression is not None:
            schedule.cron_expression = updates.cron_expression
            time_fields_updated = True

        if updates.timezone is not None:
            schedule.timezone = updates.timezone
            time_fields_updated = True

        if time_fields_updated:
            schedule.next_run_at = calculate_next_run_at(
                schedule.cron_expression,
                schedule.timezone,
            )

        session.flush()
        return schedule
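A usage sketch for create_schedule/update_schedule above (session handling and IDs are hypothetical; the ScheduleConfig and SchedulePlanUpdate constructors are assumed to accept exactly the fields the methods read, with the rest optional). Note both methods only flush, so the caller owns the commit:

    from sqlalchemy.orm import Session
    from extensions.ext_database import db

    with Session(db.engine) as session:
        plan = ScheduleService.create_schedule(
            session=session,
            tenant_id="tenant-uuid",
            app_id="app-uuid",
            config=ScheduleConfig(node_id="start", cron_expression="0 9 * * 1", timezone="UTC"),
        )
        # Reschedule to weekdays at 09:00; next_run_at is recomputed because a time field changed
        ScheduleService.update_schedule(session, plan.id, SchedulePlanUpdate(cron_expression="0 9 * * 1-5"))
        session.commit()
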
    @staticmethod
    def delete_schedule(
        session: Session,
        schedule_id: str,
    ) -> None:
        """
        Delete a schedule plan.

        Args:
            session: Database session
            schedule_id: Schedule ID to delete
        """
        schedule = session.get(WorkflowSchedulePlan, schedule_id)
        if not schedule:
            raise ScheduleNotFoundError(f"Schedule not found: {schedule_id}")

        session.delete(schedule)
        session.flush()

    @staticmethod
    def get_tenant_owner(session: Session, tenant_id: str) -> Account:
        """
        Returns an account to execute scheduled workflows on behalf of the tenant.
        Prioritizes owner over admin to ensure proper authorization hierarchy.
        """
        result = session.execute(
            select(TenantAccountJoin)
            .where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.role == "owner")
            .limit(1)
        ).scalar_one_or_none()

        if not result:
            # An owner may not exist in some tenant configurations; fall back to admin
            result = session.execute(
                select(TenantAccountJoin)
                .where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.role == "admin")
                .limit(1)
            ).scalar_one_or_none()

        if result:
            account = session.get(Account, result.account_id)
            if not account:
                raise AccountNotFoundError(f"Account not found: {result.account_id}")
            return account
        else:
            raise AccountNotFoundError(f"Account not found for tenant: {tenant_id}")

    @staticmethod
    def update_next_run_at(
        session: Session,
        schedule_id: str,
    ) -> datetime:
        """
        Advances the schedule to its next execution time after a successful trigger.
        Uses the current time as the base to avoid missing executions during delays.
        """
        schedule = session.get(WorkflowSchedulePlan, schedule_id)
        if not schedule:
            raise ScheduleNotFoundError(f"Schedule not found: {schedule_id}")

        # Base on current time to handle execution delays gracefully
        next_run_at = calculate_next_run_at(
            schedule.cron_expression,
            schedule.timezone,
        )

        schedule.next_run_at = next_run_at
        session.flush()
        return next_run_at

    @staticmethod
    def to_schedule_config(node_config: Mapping[str, Any]) -> ScheduleConfig:
        """
        Converts a schedule trigger node configuration (visual or cron mode)
        into a unified ScheduleConfig.
        """
        node_data = node_config.get("data", {})
        mode = node_data.get("mode", "visual")
        timezone = node_data.get("timezone", "UTC")
        node_id = node_config.get("id", "start")

        cron_expression = None
        if mode == "cron":
            cron_expression = node_data.get("cron_expression")
            if not cron_expression:
                raise ScheduleConfigError("Cron expression is required for cron mode")
        elif mode == "visual":
            # Read the raw value before casting: str(None) would be the truthy
            # string "None" and slip past the emptiness check below.
            frequency = node_data.get("frequency")
            if not frequency:
                raise ScheduleConfigError("Frequency is required for visual mode")
            visual_config = VisualConfig(**node_data.get("visual_config", {}))
            cron_expression = ScheduleService.visual_to_cron(frequency=str(frequency), visual_config=visual_config)
            if not cron_expression:
                raise ScheduleConfigError("Cron expression is required for visual mode")
        else:
            raise ScheduleConfigError(f"Invalid schedule mode: {mode}")
        return ScheduleConfig(node_id=node_id, cron_expression=cron_expression, timezone=timezone)

    @staticmethod
    def extract_schedule_config(workflow: Workflow) -> ScheduleConfig | None:
        """
        Extracts schedule configuration from the workflow graph.

        Searches for the first schedule trigger node in the workflow and converts
        its configuration (either visual or cron mode) into a unified ScheduleConfig.

        Args:
            workflow: The workflow containing the graph definition

        Returns:
            ScheduleConfig if a valid schedule node is found, None if no schedule node exists

        Raises:
            ScheduleConfigError: If graph parsing fails or the schedule configuration is invalid

        Note:
            Currently only returns the first schedule node found.
            Multiple schedule nodes in the same workflow are not supported.
        """
        try:
            graph_data = workflow.graph_dict
        except (json.JSONDecodeError, TypeError, AttributeError) as e:
            raise ScheduleConfigError(f"Failed to parse workflow graph: {e}") from e

        if not graph_data:
            raise ScheduleConfigError("Workflow graph is empty")

        nodes = graph_data.get("nodes", [])
        for node in nodes:
            node_data = node.get("data", {})

            if node_data.get("type") != NodeType.TRIGGER_SCHEDULE.value:
                continue

            mode = node_data.get("mode", "visual")
            timezone = node_data.get("timezone", "UTC")
            node_id = node.get("id", "start")

            cron_expression = None
            if mode == "cron":
                cron_expression = node_data.get("cron_expression")
                if not cron_expression:
                    raise ScheduleConfigError("Cron expression is required for cron mode")
            elif mode == "visual":
                frequency = node_data.get("frequency")
                if not frequency:
                    raise ScheduleConfigError("Frequency is required for visual mode")
                visual_config_dict = node_data.get("visual_config", {})
                visual_config = VisualConfig(**visual_config_dict)
                cron_expression = ScheduleService.visual_to_cron(frequency, visual_config)
            else:
                raise ScheduleConfigError(f"Invalid schedule mode: {mode}")

            return ScheduleConfig(node_id=node_id, cron_expression=cron_expression, timezone=timezone)

        return None

    @staticmethod
    def visual_to_cron(frequency: str, visual_config: VisualConfig) -> str:
        """
        Converts user-friendly visual schedule settings to a cron expression.
        Maintains consistency with frontend UI expectations while supporting croniter's extended syntax.
        """
        if frequency == "hourly":
            if visual_config.on_minute is None:
                raise ScheduleConfigError("on_minute is required for hourly schedules")
            return f"{visual_config.on_minute} * * * *"

        elif frequency == "daily":
            if not visual_config.time:
                raise ScheduleConfigError("time is required for daily schedules")
            hour, minute = convert_12h_to_24h(visual_config.time)
            return f"{minute} {hour} * * *"

        elif frequency == "weekly":
            if not visual_config.time:
                raise ScheduleConfigError("time is required for weekly schedules")
            if not visual_config.weekdays:
                raise ScheduleConfigError("Weekdays are required for weekly schedules")
            hour, minute = convert_12h_to_24h(visual_config.time)
            weekday_map = {"sun": "0", "mon": "1", "tue": "2", "wed": "3", "thu": "4", "fri": "5", "sat": "6"}
            cron_weekdays = [weekday_map[day] for day in visual_config.weekdays]
            return f"{minute} {hour} * * {','.join(sorted(cron_weekdays))}"

        elif frequency == "monthly":
            if not visual_config.time:
                raise ScheduleConfigError("time is required for monthly schedules")
            if not visual_config.monthly_days:
                raise ScheduleConfigError("Monthly days are required for monthly schedules")
            hour, minute = convert_12h_to_24h(visual_config.time)

            numeric_days: list[int] = []
            has_last = False
            for day in visual_config.monthly_days:
                if day == "last":
                    has_last = True
                else:
                    numeric_days.append(day)

            result_days = [str(d) for d in sorted(set(numeric_days))]
            if has_last:
                # croniter's extended syntax: "L" means the last day of the month
                result_days.append("L")

            return f"{minute} {hour} {','.join(result_days)} * *"

        else:
            raise ScheduleConfigError(f"Unsupported frequency: {frequency}")
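To make the visual-to-cron mapping concrete, a few input/output pairs implied by the code above (assuming VisualConfig accepts these fields as optional keywords and convert_12h_to_24h parses "9:30 AM"-style strings):

    ScheduleService.visual_to_cron("hourly", VisualConfig(on_minute=15))
    # -> "15 * * * *"
    ScheduleService.visual_to_cron("weekly", VisualConfig(time="9:30 AM", weekdays=["mon", "fri"]))
    # -> "30 9 * * 1,5"
    ScheduleService.visual_to_cron("monthly", VisualConfig(time="12:00 PM", monthly_days=[1, "last"]))
    # -> "0 12 1,L * *"  (croniter's extended "L" = last day of the month)
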
690
dify/api/services/trigger/trigger_provider_service.py
Normal file
@@ -0,0 +1,690 @@
import json
import logging
import time as _time
import uuid
from collections.abc import Mapping
from typing import Any

from sqlalchemy import desc, func
from sqlalchemy.orm import Session

from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.trigger.entities.api_entities import (
    TriggerProviderApiEntity,
    TriggerProviderSubscriptionApiEntity,
)
from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.trigger.utils.encryption import (
    create_trigger_provider_encrypter_for_properties,
    create_trigger_provider_encrypter_for_subscription,
    delete_cache_for_subscription,
)
from core.trigger.utils.endpoint import generate_plugin_trigger_endpoint_url
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.provider_ids import TriggerProviderID
from models.trigger import (
    TriggerOAuthSystemClient,
    TriggerOAuthTenantClient,
    TriggerSubscription,
    WorkflowPluginTrigger,
)
from services.plugin.plugin_service import PluginService

logger = logging.getLogger(__name__)


class TriggerProviderService:
    """Service for managing trigger providers and credentials"""

    ##########################
    # Trigger provider
    ##########################
    __MAX_TRIGGER_PROVIDER_COUNT__ = 10

    @classmethod
    def get_trigger_provider(cls, tenant_id: str, provider: TriggerProviderID) -> TriggerProviderApiEntity:
        """Get info for a trigger provider"""
        return TriggerManager.get_trigger_provider(tenant_id, provider).to_api_entity()

    @classmethod
    def list_trigger_providers(cls, tenant_id: str) -> list[TriggerProviderApiEntity]:
        """List all trigger providers for the current tenant"""
        return [provider.to_api_entity() for provider in TriggerManager.list_all_trigger_providers(tenant_id)]

    @classmethod
    def list_trigger_provider_subscriptions(
        cls, tenant_id: str, provider_id: TriggerProviderID
    ) -> list[TriggerProviderSubscriptionApiEntity]:
        """List all trigger subscriptions for the current tenant"""
        subscriptions: list[TriggerProviderSubscriptionApiEntity] = []
        workflows_in_use_map: dict[str, int] = {}
        with Session(db.engine, expire_on_commit=False) as session:
            # Get all subscriptions
            subscriptions_db = (
                session.query(TriggerSubscription)
                .filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
                .order_by(desc(TriggerSubscription.created_at))
                .all()
            )
            subscriptions = [subscription.to_api_entity() for subscription in subscriptions_db]
            if not subscriptions:
                return []
            # Count distinct apps using each subscription
            usage_counts = (
                session.query(
                    WorkflowPluginTrigger.subscription_id,
                    func.count(func.distinct(WorkflowPluginTrigger.app_id)).label("app_count"),
                )
                .filter(
                    WorkflowPluginTrigger.tenant_id == tenant_id,
                    WorkflowPluginTrigger.subscription_id.in_([s.id for s in subscriptions]),
                )
                .group_by(WorkflowPluginTrigger.subscription_id)
                .all()
            )
            workflows_in_use_map = {str(row.subscription_id): int(row.app_count) for row in usage_counts}

            provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
            for subscription in subscriptions:
                encrypter, _ = create_trigger_provider_encrypter_for_subscription(
                    tenant_id=tenant_id,
                    controller=provider_controller,
                    subscription=subscription,
                )
                # Decrypt, then mask secret values before returning them to the API layer
                subscription.credentials = dict(
                    encrypter.mask_credentials(dict(encrypter.decrypt(subscription.credentials)))
                )
                subscription.properties = dict(
                    encrypter.mask_credentials(dict(encrypter.decrypt(subscription.properties)))
                )
                subscription.parameters = dict(
                    encrypter.mask_credentials(dict(encrypter.decrypt(subscription.parameters)))
                )
                count = workflows_in_use_map.get(subscription.id)
                subscription.workflows_in_use = count if count is not None else 0

        return subscriptions

    @classmethod
    def add_trigger_subscription(
        cls,
        tenant_id: str,
        user_id: str,
        name: str,
        provider_id: TriggerProviderID,
        endpoint_id: str,
        credential_type: CredentialType,
        parameters: Mapping[str, Any],
        properties: Mapping[str, Any],
        credentials: Mapping[str, str],
        subscription_id: str | None = None,
        credential_expires_at: int = -1,
        expires_at: int = -1,
    ) -> Mapping[str, Any]:
        """
        Add a new trigger provider subscription with credentials.
        Supports multiple credential instances per provider.

        :param tenant_id: Tenant ID
        :param provider_id: Provider identifier (e.g., "plugin_id/provider_name")
        :param credential_type: Type of credential (oauth or api_key)
        :param credentials: Credential data to encrypt and store
        :param name: Name for this credential instance
        :param expires_at: OAuth token expiration timestamp
        :return: Success response
        """
        try:
            provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
            with Session(db.engine, expire_on_commit=False) as session:
                # Use a distributed lock to prevent race conditions
                lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}"
                with redis_client.lock(lock_key, timeout=20):
                    # Check the provider count limit
                    provider_count = (
                        session.query(TriggerSubscription)
                        .filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
                        .count()
                    )

                    if provider_count >= cls.__MAX_TRIGGER_PROVIDER_COUNT__:
                        raise ValueError(
                            f"Maximum number of providers ({cls.__MAX_TRIGGER_PROVIDER_COUNT__}) "
                            f"reached for {provider_id}"
                        )

                    # Check if the name already exists
                    existing = (
                        session.query(TriggerSubscription)
                        .filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
                        .first()
                    )
                    if existing:
                        raise ValueError(f"Credential name '{name}' already exists for this provider")

                    credential_encrypter: ProviderConfigEncrypter | None = None
                    if credential_type != CredentialType.UNAUTHORIZED:
                        credential_encrypter, _ = create_provider_encrypter(
                            tenant_id=tenant_id,
                            config=provider_controller.get_credential_schema_config(credential_type),
                            cache=NoOpProviderCredentialCache(),
                        )

                    properties_encrypter, _ = create_provider_encrypter(
                        tenant_id=tenant_id,
                        config=provider_controller.get_properties_schema(),
                        cache=NoOpProviderCredentialCache(),
                    )

                    # Create the subscription record
                    subscription = TriggerSubscription(
                        tenant_id=tenant_id,
                        user_id=user_id,
                        name=name,
                        endpoint_id=endpoint_id,
                        provider_id=str(provider_id),
                        parameters=dict(parameters),
                        properties=dict(properties_encrypter.encrypt(dict(properties))),
                        credentials=dict(credential_encrypter.encrypt(dict(credentials)))
                        if credential_encrypter
                        else {},
                        credential_type=credential_type.value,
                        credential_expires_at=credential_expires_at,
                        expires_at=expires_at,
                    )
                    subscription.id = subscription_id or str(uuid.uuid4())

                    session.add(subscription)
                    session.commit()

                    return {
                        "result": "success",
                        "id": str(subscription.id),
                    }

        except Exception as e:
            logger.exception("Failed to add trigger provider")
            raise ValueError(str(e)) from e
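A call sketch for add_trigger_subscription (every value is hypothetical, including the provider ID string, whose format the docstring describes as "plugin_id/provider_name"):

    result = TriggerProviderService.add_trigger_subscription(
        tenant_id="tenant-uuid",
        user_id="user-uuid",
        name="github-main",
        provider_id=TriggerProviderID("langgenius/github/github"),
        endpoint_id="endpoint-uuid",
        credential_type=CredentialType.API_KEY,
        parameters={"repository": "owner/repo"},
        properties={},
        credentials={"api_key": "..."},
    )
    assert result["result"] == "success"  # the new subscription's ID is in result["id"]
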
    @classmethod
    def get_subscription_by_id(cls, tenant_id: str, subscription_id: str | None = None) -> TriggerSubscription | None:
        """
        Get a trigger subscription by its ID, with credentials and properties decrypted.
        """
        with Session(db.engine, expire_on_commit=False) as session:
            subscription: TriggerSubscription | None = None
            if subscription_id:
                subscription = (
                    session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
                )
            else:
                subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id).first()
            if subscription:
                provider_controller = TriggerManager.get_trigger_provider(
                    tenant_id, TriggerProviderID(subscription.provider_id)
                )
                encrypter, _ = create_trigger_provider_encrypter_for_subscription(
                    tenant_id=tenant_id,
                    controller=provider_controller,
                    subscription=subscription,
                )
                subscription.credentials = dict(encrypter.decrypt(subscription.credentials))
                properties_encrypter, _ = create_trigger_provider_encrypter_for_properties(
                    tenant_id=subscription.tenant_id,
                    controller=provider_controller,
                    subscription=subscription,
                )
                subscription.properties = dict(properties_encrypter.decrypt(subscription.properties))
            return subscription

    @classmethod
    def delete_trigger_provider(cls, session: Session, tenant_id: str, subscription_id: str):
        """
        Delete a trigger provider subscription within an existing session.

        :param session: Database session
        :param tenant_id: Tenant ID
        :param subscription_id: Subscription instance ID
        """
        subscription: TriggerSubscription | None = (
            session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
        )
        if not subscription:
            raise ValueError(f"Trigger provider subscription {subscription_id} not found")

        credential_type: CredentialType = CredentialType.of(subscription.credential_type)
        is_auto_created: bool = credential_type in [CredentialType.OAUTH2, CredentialType.API_KEY]
        if is_auto_created:
            provider_id = TriggerProviderID(subscription.provider_id)
            provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
                tenant_id=tenant_id, provider_id=provider_id
            )
            encrypter, _ = create_trigger_provider_encrypter_for_subscription(
                tenant_id=tenant_id,
                controller=provider_controller,
                subscription=subscription,
            )
            try:
                TriggerManager.unsubscribe_trigger(
                    tenant_id=tenant_id,
                    user_id=subscription.user_id,
                    provider_id=provider_id,
                    subscription=subscription.to_entity(),
                    credentials=encrypter.decrypt(subscription.credentials),
                    credential_type=credential_type,
                )
            except Exception:
                logger.exception("Error unsubscribing trigger")

        # Delete the record and clear its cache
        session.delete(subscription)
        delete_cache_for_subscription(
            tenant_id=tenant_id,
            provider_id=subscription.provider_id,
            subscription_id=subscription.id,
        )

    @classmethod
    def refresh_oauth_token(
        cls,
        tenant_id: str,
        subscription_id: str,
    ) -> Mapping[str, Any]:
        """
        Refresh the OAuth token for a trigger provider subscription.

        :param tenant_id: Tenant ID
        :param subscription_id: Subscription instance ID
        :return: New token info
        """
        with Session(db.engine) as session:
            subscription = (
                session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
            )

            if not subscription:
                raise ValueError(f"Trigger provider subscription {subscription_id} not found")

            if subscription.credential_type != CredentialType.OAUTH2.value:
                raise ValueError("Only OAuth credentials can be refreshed")

            provider_id = TriggerProviderID(subscription.provider_id)
            provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
                tenant_id=tenant_id, provider_id=provider_id
            )
            # Create encrypter
            encrypter, cache = create_provider_encrypter(
                tenant_id=tenant_id,
                config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
                cache=NoOpProviderCredentialCache(),
            )

            # Decrypt current credentials
            current_credentials = encrypter.decrypt(subscription.credentials)

            # Get OAuth client configuration
            redirect_uri = (
                f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{subscription.provider_id}/trigger/callback"
            )
            system_credentials = cls.get_oauth_client(tenant_id, provider_id)

            # Refresh the token
            oauth_handler = OAuthHandler()
            refreshed_credentials = oauth_handler.refresh_credentials(
                tenant_id=tenant_id,
                user_id=subscription.user_id,
                plugin_id=provider_id.plugin_id,
                provider=provider_id.provider_name,
                redirect_uri=redirect_uri,
                system_credentials=system_credentials or {},
                credentials=current_credentials,
            )

            # Update credentials
            subscription.credentials = dict(encrypter.encrypt(dict(refreshed_credentials.credentials)))
            subscription.credential_expires_at = refreshed_credentials.expires_at
            session.commit()

            # Clear cache
            cache.delete()

            return {
                "result": "success",
                "expires_at": refreshed_credentials.expires_at,
            }

    @classmethod
    def refresh_subscription(
        cls,
        tenant_id: str,
        subscription_id: str,
        now: int | None = None,
    ) -> Mapping[str, Any]:
        """
        Refresh a trigger subscription if it has expired.

        Args:
            tenant_id: Tenant ID
            subscription_id: Subscription instance ID
            now: Current timestamp, defaults to `int(time.time())`

        Returns:
            Mapping with keys: `result` ("success"|"skipped") and `expires_at` (new or existing value)
        """
        now_ts: int = int(now if now is not None else _time.time())

        with Session(db.engine) as session:
            subscription: TriggerSubscription | None = (
                session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
            )
            if subscription is None:
                raise ValueError(f"Trigger provider subscription {subscription_id} not found")

            # expires_at == -1 marks a never-expiring subscription
            if subscription.expires_at == -1 or int(subscription.expires_at) > now_ts:
                logger.debug(
                    "Subscription not due for refresh: tenant=%s id=%s expires_at=%s now=%s",
                    tenant_id,
                    subscription_id,
                    subscription.expires_at,
                    now_ts,
                )
                return {"result": "skipped", "expires_at": int(subscription.expires_at)}

            provider_id = TriggerProviderID(subscription.provider_id)
            controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
                tenant_id=tenant_id, provider_id=provider_id
            )

            # Decrypt credentials and properties for runtime use
            credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
                tenant_id=tenant_id,
                controller=controller,
                subscription=subscription,
            )
            properties_encrypter, properties_cache = create_trigger_provider_encrypter_for_properties(
                tenant_id=tenant_id,
                controller=controller,
                subscription=subscription,
            )

            decrypted_credentials = credential_encrypter.decrypt(subscription.credentials)
            decrypted_properties = properties_encrypter.decrypt(subscription.properties)

            sub_entity: TriggerSubscriptionEntity = TriggerSubscriptionEntity(
                expires_at=int(subscription.expires_at),
                endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
                parameters=subscription.parameters,
                properties=decrypted_properties,
            )

            refreshed: TriggerSubscriptionEntity = controller.refresh_trigger(
                subscription=sub_entity,
                credentials=decrypted_credentials,
                credential_type=CredentialType.of(subscription.credential_type),
            )

            # Persist refreshed properties and expires_at
            subscription.properties = dict(properties_encrypter.encrypt(dict(refreshed.properties)))
            subscription.expires_at = int(refreshed.expires_at)
            session.commit()
            properties_cache.delete()

            logger.info(
                "Subscription refreshed (service): tenant=%s id=%s new_expires_at=%s",
                tenant_id,
                subscription_id,
                subscription.expires_at,
            )

            return {"result": "success", "expires_at": int(refreshed.expires_at)}

    @classmethod
    def get_oauth_client(cls, tenant_id: str, provider_id: TriggerProviderID) -> Mapping[str, Any] | None:
        """
        Get the OAuth client configuration for a provider.
        Tries the tenant-level OAuth client first, then falls back to the system OAuth client.

        :param tenant_id: Tenant ID
        :param provider_id: Provider identifier
        :return: OAuth client configuration or None
        """
        provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
            tenant_id=tenant_id, provider_id=provider_id
        )
        with Session(db.engine, expire_on_commit=False) as session:
            tenant_client: TriggerOAuthTenantClient | None = (
                session.query(TriggerOAuthTenantClient)
                .filter_by(
                    tenant_id=tenant_id,
                    provider=provider_id.provider_name,
                    plugin_id=provider_id.plugin_id,
                    enabled=True,
                )
                .first()
            )

            oauth_params: Mapping[str, Any] | None = None
            if tenant_client:
                encrypter, _ = create_provider_encrypter(
                    tenant_id=tenant_id,
                    config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
                    cache=NoOpProviderCredentialCache(),
                )
                oauth_params = encrypter.decrypt(dict(tenant_client.oauth_params))
                return oauth_params

            # System OAuth clients are only exposed for verified plugins
            is_verified = PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
            if not is_verified:
                return None

            # Check for a system-level OAuth client
            system_client: TriggerOAuthSystemClient | None = (
                session.query(TriggerOAuthSystemClient)
                .filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name)
                .first()
            )

            if system_client:
                try:
                    oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
                except Exception as e:
                    raise ValueError(f"Error decrypting system oauth params: {e}") from e

            return oauth_params

    @classmethod
    def is_oauth_system_client_exists(cls, tenant_id: str, provider_id: TriggerProviderID) -> bool:
        """
        Check if a system OAuth client exists for a trigger provider.
        """
        provider_controller = TriggerManager.get_trigger_provider(tenant_id=tenant_id, provider_id=provider_id)
        is_verified = PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
        if not is_verified:
            return False
        with Session(db.engine, expire_on_commit=False) as session:
            system_client: TriggerOAuthSystemClient | None = (
                session.query(TriggerOAuthSystemClient)
                .filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name)
                .first()
            )
            return system_client is not None

    @classmethod
    def save_custom_oauth_client_params(
        cls,
        tenant_id: str,
        provider_id: TriggerProviderID,
        client_params: Mapping[str, Any] | None = None,
        enabled: bool | None = None,
    ) -> Mapping[str, Any]:
        """
        Save or update custom OAuth client parameters for a trigger provider.

        :param tenant_id: Tenant ID
        :param provider_id: Provider identifier
        :param client_params: OAuth client parameters (client_id, client_secret, etc.)
        :param enabled: Enable/disable the custom OAuth client
        :return: Success response
        """
        if client_params is None and enabled is None:
            return {"result": "success"}

        # Get the provider controller to access the schema
        provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
            tenant_id=tenant_id, provider_id=provider_id
        )

        with Session(db.engine) as session:
            # Find existing custom client params
            custom_client = (
                session.query(TriggerOAuthTenantClient)
                .filter_by(
                    tenant_id=tenant_id,
                    plugin_id=provider_id.plugin_id,
                    provider=provider_id.provider_name,
                )
                .first()
            )

            # Create a new record if one doesn't exist
            if custom_client is None:
                custom_client = TriggerOAuthTenantClient(
                    tenant_id=tenant_id,
                    plugin_id=provider_id.plugin_id,
                    provider=provider_id.provider_name,
                )
                session.add(custom_client)

            # Update client params (reset to empty when not provided)
            if client_params is None:
                custom_client.encrypted_oauth_params = json.dumps({})
            else:
                encrypter, cache = create_provider_encrypter(
                    tenant_id=tenant_id,
                    config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
                    cache=NoOpProviderCredentialCache(),
                )

                # Preserve values the client sent back as HIDDEN_VALUE placeholders
                original_params = encrypter.decrypt(dict(custom_client.oauth_params))
                new_params: dict[str, Any] = {
                    key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
                    for key, value in client_params.items()
                }
                custom_client.encrypted_oauth_params = json.dumps(encrypter.encrypt(new_params))
                cache.delete()

            # Update enabled status if provided
            if enabled is not None:
                custom_client.enabled = enabled

            session.commit()

        return {"result": "success"}

    @classmethod
    def get_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> Mapping[str, Any]:
        """
        Get custom OAuth client parameters for a trigger provider.

        :param tenant_id: Tenant ID
        :param provider_id: Provider identifier
        :return: Masked OAuth client parameters
        """
        with Session(db.engine) as session:
            custom_client = (
                session.query(TriggerOAuthTenantClient)
                .filter_by(
                    tenant_id=tenant_id,
                    plugin_id=provider_id.plugin_id,
                    provider=provider_id.provider_name,
                )
                .first()
            )

            if custom_client is None:
                return {}

            # Get the provider controller to access the schema
            provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
                tenant_id=tenant_id, provider_id=provider_id
            )

            # Create an encrypter to decrypt and mask values
            encrypter, _ = create_provider_encrypter(
                tenant_id=tenant_id,
                config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
                cache=NoOpProviderCredentialCache(),
            )

            return encrypter.mask_plugin_credentials(encrypter.decrypt(dict(custom_client.oauth_params)))

    @classmethod
    def delete_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> Mapping[str, Any]:
        """
        Delete custom OAuth client parameters for a trigger provider.

        :param tenant_id: Tenant ID
        :param provider_id: Provider identifier
        :return: Success response
        """
        with Session(db.engine) as session:
            session.query(TriggerOAuthTenantClient).filter_by(
                tenant_id=tenant_id,
                provider=provider_id.provider_name,
                plugin_id=provider_id.plugin_id,
            ).delete()
            session.commit()

        return {"result": "success"}

    @classmethod
    def is_oauth_custom_client_enabled(cls, tenant_id: str, provider_id: TriggerProviderID) -> bool:
        """
        Check if a custom OAuth client is enabled for a trigger provider.

        :param tenant_id: Tenant ID
        :param provider_id: Provider identifier
        :return: True if enabled, False otherwise
        """
        with Session(db.engine, expire_on_commit=False) as session:
            custom_client = (
                session.query(TriggerOAuthTenantClient)
                .filter_by(
                    tenant_id=tenant_id,
                    plugin_id=provider_id.plugin_id,
                    provider=provider_id.provider_name,
                    enabled=True,
                )
                .first()
            )
            return custom_client is not None

    @classmethod
    def get_subscription_by_endpoint(cls, endpoint_id: str) -> TriggerSubscription | None:
        """
        Get a trigger subscription by its endpoint ID, with credentials and properties decrypted.
        """
        with Session(db.engine, expire_on_commit=False) as session:
            subscription = session.query(TriggerSubscription).filter_by(endpoint_id=endpoint_id).first()
            if not subscription:
                return None
            provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
                tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id)
            )
            credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
                tenant_id=subscription.tenant_id,
                controller=provider_controller,
                subscription=subscription,
            )
            subscription.credentials = dict(credential_encrypter.decrypt(subscription.credentials))

            properties_encrypter, _ = create_trigger_provider_encrypter_for_properties(
                tenant_id=subscription.tenant_id,
                controller=provider_controller,
                subscription=subscription,
            )
            subscription.properties = dict(properties_encrypter.decrypt(subscription.properties))
            return subscription
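A sketch of the periodic refresh flow this service supports (IDs hypothetical). Since expires_at == -1 marks a never-expiring subscription, a scheduled job can call this unconditionally and treat "skipped" as a no-op:

    outcome = TriggerProviderService.refresh_subscription(
        tenant_id="tenant-uuid",
        subscription_id="subscription-uuid",
    )
    if outcome["result"] == "success":
        print("subscription renewed until", outcome["expires_at"])
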
65
dify/api/services/trigger/trigger_request_service.py
Normal file
@@ -0,0 +1,65 @@
from collections.abc import Mapping
from typing import Any

from flask import Request
from pydantic import TypeAdapter

from core.plugin.utils.http_parser import deserialize_request, serialize_request
from extensions.ext_storage import storage


class TriggerHttpRequestCachingService:
    """
    Service for caching trigger requests.
    """

    _TRIGGER_STORAGE_PATH = "triggers"

    @classmethod
    def get_request(cls, request_id: str) -> Request:
        """
        Get the request object from storage.

        Args:
            request_id: The ID of the request.

        Returns:
            The request object.
        """
        return deserialize_request(storage.load_once(f"{cls._TRIGGER_STORAGE_PATH}/{request_id}.raw"))

    @classmethod
    def get_payload(cls, request_id: str) -> Mapping[str, Any]:
        """
        Get the payload from storage.

        Args:
            request_id: The ID of the request.

        Returns:
            The payload.
        """
        return TypeAdapter(Mapping[str, Any]).validate_json(
            storage.load_once(f"{cls._TRIGGER_STORAGE_PATH}/{request_id}.payload")
        )

    @classmethod
    def persist_request(cls, request_id: str, request: Request) -> None:
        """
        Persist the request in storage.

        Args:
            request_id: The ID of the request.
            request: The request object.
        """
        storage.save(f"{cls._TRIGGER_STORAGE_PATH}/{request_id}.raw", serialize_request(request))

    @classmethod
    def persist_payload(cls, request_id: str, payload: Mapping[str, Any]) -> None:
        """
        Persist the payload in storage.
        """
        storage.save(
            f"{cls._TRIGGER_STORAGE_PATH}/{request_id}.payload",
            TypeAdapter(Mapping[str, Any]).dump_json(payload),  # type: ignore
        )
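A round-trip sketch for the caching service (the request_id format mirrors the one generated in trigger_service.py below; the load_once name suggests each stored blob is meant to be read back exactly once):

    # Inside a Flask request context, when a trigger fires:
    request_id = "trigger_request_1700000000_abc123"
    TriggerHttpRequestCachingService.persist_request(request_id, request)
    TriggerHttpRequestCachingService.persist_payload(request_id, {"event": "push"})

    # Later, e.g. in an async worker:
    req = TriggerHttpRequestCachingService.get_request(request_id)
    payload = TriggerHttpRequestCachingService.get_payload(request_id)
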
304
dify/api/services/trigger/trigger_service.py
Normal file
@@ -0,0 +1,304 @@
import logging
import secrets
import time
from collections.abc import Mapping
from typing import Any

from flask import Request, Response
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session

from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import TriggerDispatchResponse, TriggerInvokeEventResponse
from core.plugin.impl.exc import PluginNotFoundError
from core.trigger.debug.events import PluginTriggerDebugEvent
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_subscription
from core.workflow.enums import NodeType
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import App
from models.provider_ids import TriggerProviderID
from models.trigger import TriggerSubscription, WorkflowPluginTrigger
from models.workflow import Workflow
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_request_service import TriggerHttpRequestCachingService
from services.workflow.entities import PluginTriggerDispatchData
from tasks.trigger_processing_tasks import dispatch_triggered_workflows_async

logger = logging.getLogger(__name__)


class TriggerService:
    __TEMPORARY_ENDPOINT_EXPIRE_MS__ = 5 * 60 * 1000
    __ENDPOINT_REQUEST_CACHE_COUNT__ = 10
    __ENDPOINT_REQUEST_CACHE_EXPIRE_MS__ = 5 * 60 * 1000
    __PLUGIN_TRIGGER_NODE_CACHE_KEY__ = "plugin_trigger_nodes"
    MAX_PLUGIN_TRIGGER_NODES_PER_WORKFLOW = 5  # Maximum allowed plugin trigger nodes per workflow

    @classmethod
    def invoke_trigger_event(
        cls, tenant_id: str, user_id: str, node_config: Mapping[str, Any], event: PluginTriggerDebugEvent
    ) -> TriggerInvokeEventResponse:
        """Invoke a trigger event."""
        subscription: TriggerSubscription | None = TriggerProviderService.get_subscription_by_id(
            tenant_id=tenant_id,
            subscription_id=event.subscription_id,
        )
        if not subscription:
            raise ValueError("Subscription not found")
        node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(node_config.get("data", {}))
        request = TriggerHttpRequestCachingService.get_request(event.request_id)
        payload = TriggerHttpRequestCachingService.get_payload(event.request_id)
        # Invoke the trigger
        provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
            tenant_id, TriggerProviderID(subscription.provider_id)
        )
        return TriggerManager.invoke_trigger_event(
            tenant_id=tenant_id,
            user_id=user_id,
            provider_id=TriggerProviderID(event.provider_id),
            event_name=event.name,
            parameters=node_data.resolve_parameters(
                parameter_schemas=provider_controller.get_event_parameters(event_name=event.name)
            ),
            credentials=subscription.credentials,
            credential_type=CredentialType.of(subscription.credential_type),
            subscription=subscription.to_entity(),
            request=request,
            payload=payload,
        )

    @classmethod
    def process_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
        """
        Extract and process data from an incoming endpoint request.

        Args:
            endpoint_id: Endpoint ID
            request: Request
        """
        timestamp = int(time.time())
        subscription: TriggerSubscription | None = None
        try:
            subscription = TriggerProviderService.get_subscription_by_endpoint(endpoint_id)
        except PluginNotFoundError:
            return Response(status=404, response="Trigger provider not found")
        except Exception:
            return Response(status=500, response="Failed to get subscription by endpoint")

        if not subscription:
            return None

        provider_id = TriggerProviderID(subscription.provider_id)
        controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
            tenant_id=subscription.tenant_id, provider_id=provider_id
        )
        encrypter, _ = create_trigger_provider_encrypter_for_subscription(
            tenant_id=subscription.tenant_id,
            controller=controller,
            subscription=subscription,
        )
        dispatch_response: TriggerDispatchResponse = controller.dispatch(
            request=request,
            subscription=subscription.to_entity(),
            credentials=encrypter.decrypt(subscription.credentials),
            credential_type=CredentialType.of(subscription.credential_type),
        )

        if dispatch_response.events:
            request_id = f"trigger_request_{timestamp}_{secrets.token_hex(6)}"

            # Save the request and payload to storage as persistent data
            TriggerHttpRequestCachingService.persist_request(request_id, request)
            TriggerHttpRequestCachingService.persist_payload(request_id, dispatch_response.payload)

            # Validate event names
            for event_name in dispatch_response.events:
                if controller.get_event(event_name) is None:
                    logger.error(
                        "Event name %s not found in provider %s for endpoint %s",
                        event_name,
                        subscription.provider_id,
                        endpoint_id,
                    )
                    raise ValueError(f"Event name {event_name} not found in provider {subscription.provider_id}")

            plugin_trigger_dispatch_data = PluginTriggerDispatchData(
                user_id=dispatch_response.user_id,
                tenant_id=subscription.tenant_id,
                endpoint_id=endpoint_id,
                provider_id=subscription.provider_id,
                subscription_id=subscription.id,
                timestamp=timestamp,
                events=list(dispatch_response.events),
                request_id=request_id,
            )
            dispatch_data = plugin_trigger_dispatch_data.model_dump(mode="json")
            dispatch_triggered_workflows_async.delay(dispatch_data)

            logger.info(
                "Queued async dispatching for %d triggers on endpoint %s with request_id %s",
                len(dispatch_response.events),
                endpoint_id,
                request_id,
            )
        return dispatch_response.response

    @classmethod
    def sync_plugin_trigger_relationships(cls, app: App, workflow: Workflow):
        """
        Sync plugin trigger relationships in the DB.

        1. Check if the workflow has any plugin trigger nodes
        2. Fetch the nodes from the DB to see if any plugin trigger records already exist
        3. Diff the nodes against the plugin trigger records and create/update/delete records as needed

        Approach:
            Frequent DB operations can cause performance issues, so Redis is used as a cache;
            any record that exists is cached.

        Limits:
            - Maximum 5 plugin trigger nodes per workflow
        """

        class Cache(BaseModel):
            """
            Cache model for plugin trigger nodes
            """

            record_id: str
            node_id: str
            provider_id: str
            event_name: str
            subscription_id: str

        # Walk nodes to find plugin triggers
        nodes_in_graph: list[Mapping[str, Any]] = []
        for node_id, node_config in workflow.walk_nodes(NodeType.TRIGGER_PLUGIN):
            # Extract plugin trigger configuration from the node
            plugin_id = node_config.get("plugin_id", "")
            provider_id = node_config.get("provider_id", "")
            event_name = node_config.get("event_name", "")
            subscription_id = node_config.get("subscription_id", "")

            if not subscription_id:
                continue

            nodes_in_graph.append(
                {
                    "node_id": node_id,
                    "plugin_id": plugin_id,
                    "provider_id": provider_id,
                    "event_name": event_name,
                    "subscription_id": subscription_id,
                }
            )

        # Check the plugin trigger node limit
        if len(nodes_in_graph) > cls.MAX_PLUGIN_TRIGGER_NODES_PER_WORKFLOW:
            raise ValueError(
                f"Workflow exceeds maximum plugin trigger node limit. "
                f"Found {len(nodes_in_graph)} plugin trigger nodes, "
                f"maximum allowed is {cls.MAX_PLUGIN_TRIGGER_NODES_PER_WORKFLOW}"
            )

        not_found_in_cache: list[Mapping[str, Any]] = []
        for node_info in nodes_in_graph:
            node_id = node_info["node_id"]
            # First check whether the node exists in the cache
            if not redis_client.get(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{app.id}:{node_id}"):
                not_found_in_cache.append(node_info)
                continue

        with Session(db.engine) as session:
            try:
                # Acquire a distributed lock to serialize concurrent plugin trigger creation;
                # the key is released via delete() in the finally block.
                redis_client.lock(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10).acquire()
                # Fetch existing records from the DB
                all_records = session.scalars(
                    select(WorkflowPluginTrigger).where(
                        WorkflowPluginTrigger.app_id == app.id,
                        WorkflowPluginTrigger.tenant_id == app.tenant_id,
                    )
                ).all()

                nodes_id_in_db = {node.node_id: node for node in all_records}
                nodes_id_in_graph = {node["node_id"] for node in nodes_in_graph}

                # Nodes found neither in the cache nor in the DB
                nodes_not_found = [
                    node_info for node_info in not_found_in_cache if node_info["node_id"] not in nodes_id_in_db
                ]

                # Create new plugin trigger records
                for node_info in nodes_not_found:
                    plugin_trigger = WorkflowPluginTrigger(
                        app_id=app.id,
                        tenant_id=app.tenant_id,
                        node_id=node_info["node_id"],
                        provider_id=node_info["provider_id"],
                        event_name=node_info["event_name"],
                        subscription_id=node_info["subscription_id"],
                    )
                    session.add(plugin_trigger)
                    session.flush()  # Get the ID for caching

                    cache = Cache(
                        record_id=plugin_trigger.id,
                        node_id=node_info["node_id"],
                        provider_id=node_info["provider_id"],
                        event_name=node_info["event_name"],
                        subscription_id=node_info["subscription_id"],
                    )
                    redis_client.set(
                        f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{app.id}:{node_info['node_id']}",
                        cache.model_dump_json(),
                        ex=60 * 60,
                    )
                session.commit()

                # Update existing records if subscription_id, provider_id, or event_name changed
                for node_info in nodes_in_graph:
                    node_id = node_info["node_id"]
                    if node_id in nodes_id_in_db:
                        existing_record = nodes_id_in_db[node_id]
                        if (
                            existing_record.subscription_id != node_info["subscription_id"]
                            or existing_record.provider_id != node_info["provider_id"]
                            or existing_record.event_name != node_info["event_name"]
                        ):
                            existing_record.subscription_id = node_info["subscription_id"]
                            existing_record.provider_id = node_info["provider_id"]
                            existing_record.event_name = node_info["event_name"]
                            session.add(existing_record)

                            # Update the cache
                            cache = Cache(
                                record_id=existing_record.id,
                                node_id=node_id,
                                provider_id=node_info["provider_id"],
                                event_name=node_info["event_name"],
                                subscription_id=node_info["subscription_id"],
                            )
                            redis_client.set(
                                f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{app.id}:{node_id}",
                                cache.model_dump_json(),
                                ex=60 * 60,
                            )
                session.commit()

                # Delete records whose nodes are no longer in the graph
                for node_id in nodes_id_in_db:
                    if node_id not in nodes_id_in_graph:
                        session.delete(nodes_id_in_db[node_id])
                        redis_client.delete(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{app.id}:{node_id}")
                session.commit()
            except Exception:
                logger.exception("Failed to sync plugin trigger relationships for app %s", app.id)
                raise
            finally:
                redis_client.delete(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:apps:{app.id}:lock")
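A wiring sketch for process_endpoint (the route path and controller layer here are hypothetical, not part of this commit):

    from flask import Flask, request

    app = Flask(__name__)

    @app.route("/triggers/plugin/<endpoint_id>", methods=["GET", "POST"])
    def plugin_trigger_endpoint(endpoint_id: str):
        # Returns the provider's dispatch response, or None when no subscription matches
        response = TriggerService.process_endpoint(endpoint_id, request)
        return response if response is not None else ("Not found", 404)
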
@@ -0,0 +1,492 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from flask import Request, Response
|
||||
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.entities.request import TriggerDispatchResponse
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.trigger.entities.api_entities import SubscriptionBuilderApiEntity
|
||||
from core.trigger.entities.entities import (
|
||||
RequestLog,
|
||||
Subscription,
|
||||
SubscriptionBuilder,
|
||||
SubscriptionBuilderUpdater,
|
||||
SubscriptionConstructor,
|
||||
)
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.trigger.utils.encryption import masked_credentials
|
||||
from core.trigger.utils.endpoint import generate_plugin_trigger_endpoint_url
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderService:
    """Service for managing trigger subscription builders and their credentials."""

    ##########################
    # Trigger provider
    ##########################
    __MAX_TRIGGER_PROVIDER_COUNT__ = 10

    ##########################
    # Builder endpoint
    ##########################
    __BUILDER_CACHE_EXPIRE_SECONDS__ = 30 * 60

    __VALIDATION_REQUEST_CACHE_COUNT__ = 10
    __VALIDATION_REQUEST_CACHE_EXPIRE_SECONDS__ = 30 * 60

    ##########################
    # Distributed lock
    ##########################
    __LOCK_EXPIRE_SECONDS__ = 30

    @classmethod
    def encode_cache_key(cls, subscription_id: str) -> str:
        return f"trigger:subscription:builder:{subscription_id}"

    @classmethod
    def encode_lock_key(cls, subscription_id: str) -> str:
        return f"trigger:subscription:builder:lock:{subscription_id}"

    @classmethod
    @contextmanager
    def acquire_builder_lock(cls, subscription_id: str):
        """
        Acquire a distributed lock for a subscription builder.

        :param subscription_id: The subscription builder ID
        """
        lock_key = cls.encode_lock_key(subscription_id)
        with redis_client.lock(lock_key, timeout=cls.__LOCK_EXPIRE_SECONDS__):
            yield

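    # Usage sketch (hypothetical caller, illustrative only): concurrent writers
    # to the same builder serialize on the Redis lock; a second caller blocks
    # until the first exits the `with` block or the 30-second timeout expires.
    #
    #     with TriggerSubscriptionBuilderService.acquire_builder_lock(builder_id):
    #         builder = TriggerSubscriptionBuilderService.get_subscription_builder(builder_id)
    #         ...  # mutate and re-cache the builder safely
    #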
    @classmethod
    def verify_trigger_subscription_builder(
        cls,
        tenant_id: str,
        user_id: str,
        provider_id: TriggerProviderID,
        subscription_builder_id: str,
    ) -> Mapping[str, Any]:
        """Verify a trigger subscription builder."""
        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
        if not provider_controller:
            raise ValueError(f"Provider {provider_id} not found")

        subscription_builder = cls.get_subscription_builder(subscription_builder_id)
        if not subscription_builder:
            raise ValueError(f"Subscription builder {subscription_builder_id} not found")

        if subscription_builder.credential_type == CredentialType.OAUTH2:
            return {"verified": bool(subscription_builder.credentials)}

        if subscription_builder.credential_type == CredentialType.API_KEY:
            credentials_to_validate = subscription_builder.credentials
            try:
                provider_controller.validate_credentials(user_id, credentials_to_validate)
            except ToolProviderCredentialValidationError as e:
                raise ValueError(f"Invalid credentials: {e}") from e
            return {"verified": True}

        return {"verified": True}

    @classmethod
    def build_trigger_subscription_builder(
        cls, tenant_id: str, user_id: str, provider_id: TriggerProviderID, subscription_builder_id: str
    ) -> None:
        """Build a trigger subscription builder."""
        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
        if not provider_controller:
            raise ValueError(f"Provider {provider_id} not found")

        # Acquire lock to prevent concurrent build operations
        with cls.acquire_builder_lock(subscription_builder_id):
            subscription_builder = cls.get_subscription_builder(subscription_builder_id)
            if not subscription_builder:
                raise ValueError(f"Subscription builder {subscription_builder_id} not found")

            if not subscription_builder.name:
                raise ValueError("Subscription builder name is required")

            credential_type = CredentialType.of(
                subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value
            )
            if credential_type == CredentialType.UNAUTHORIZED:
                # manually create
                TriggerProviderService.add_trigger_subscription(
                    subscription_id=subscription_builder.id,
                    tenant_id=tenant_id,
                    user_id=user_id,
                    name=subscription_builder.name,
                    provider_id=provider_id,
                    endpoint_id=subscription_builder.endpoint_id,
                    parameters=subscription_builder.parameters,
                    properties=subscription_builder.properties,
                    credential_expires_at=subscription_builder.credential_expires_at or -1,
                    expires_at=subscription_builder.expires_at,
                    credentials=subscription_builder.credentials,
                    credential_type=credential_type,
                )
            else:
                # automatically create
                subscription: Subscription = TriggerManager.subscribe_trigger(
                    tenant_id=tenant_id,
                    user_id=user_id,
                    provider_id=provider_id,
                    endpoint=generate_plugin_trigger_endpoint_url(subscription_builder.endpoint_id),
                    parameters=subscription_builder.parameters,
                    credentials=subscription_builder.credentials,
                    credential_type=credential_type,
                )

                TriggerProviderService.add_trigger_subscription(
                    subscription_id=subscription_builder.id,
                    tenant_id=tenant_id,
                    user_id=user_id,
                    name=subscription_builder.name,
                    provider_id=provider_id,
                    endpoint_id=subscription_builder.endpoint_id,
                    parameters=subscription_builder.parameters,
                    properties=subscription.properties,
                    credentials=subscription_builder.credentials,
                    credential_type=credential_type,
                    credential_expires_at=subscription_builder.credential_expires_at or -1,
                    expires_at=subscription_builder.expires_at,
                )

            # Delete the builder after successful subscription creation
            cache_key = cls.encode_cache_key(subscription_builder_id)
            redis_client.delete(cache_key)

    @classmethod
    def create_trigger_subscription_builder(
        cls,
        tenant_id: str,
        user_id: str,
        provider_id: TriggerProviderID,
        credential_type: CredentialType,
    ) -> SubscriptionBuilderApiEntity:
        """
        Create a new trigger subscription builder.
        """
        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
        if not provider_controller:
            raise ValueError(f"Provider {provider_id} not found")

        subscription_constructor: SubscriptionConstructor | None = provider_controller.get_subscription_constructor()
        subscription_id = str(uuid.uuid4())
        subscription_builder = SubscriptionBuilder(
            id=subscription_id,
            name=None,
            endpoint_id=subscription_id,
            tenant_id=tenant_id,
            user_id=user_id,
            provider_id=str(provider_id),
            parameters=subscription_constructor.get_default_parameters() if subscription_constructor else {},
            properties=provider_controller.get_subscription_default_properties(),
            credentials={},
            credential_type=credential_type,
            credential_expires_at=-1,
            expires_at=-1,
        )
        cache_key = cls.encode_cache_key(subscription_id)
        redis_client.setex(cache_key, cls.__BUILDER_CACHE_EXPIRE_SECONDS__, subscription_builder.model_dump_json())
        return cls.builder_to_api_entity(controller=provider_controller, entity=subscription_builder)

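    # Typical builder lifecycle, shown as a hypothetical caller (assumes an
    # API-key provider; illustrative, not part of this module):
    #
    #     builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
    #         tenant_id, user_id, provider_id, CredentialType.API_KEY
    #     )
    #     # ... set name/credentials and validate via update_and_verify_builder ...
    #     TriggerSubscriptionBuilderService.build_trigger_subscription_builder(
    #         tenant_id, user_id, provider_id, builder.id
    #     )  # persists the subscription and drops the cached builder
    #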
    @classmethod
    def update_trigger_subscription_builder(
        cls,
        tenant_id: str,
        provider_id: TriggerProviderID,
        subscription_builder_id: str,
        subscription_builder_updater: SubscriptionBuilderUpdater,
    ) -> SubscriptionBuilderApiEntity:
        """
        Update a trigger subscription builder.
        """
        subscription_id = subscription_builder_id
        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
        if not provider_controller:
            raise ValueError(f"Provider {provider_id} not found")

        # Acquire lock to prevent concurrent updates
        with cls.acquire_builder_lock(subscription_id):
            cache_key = cls.encode_cache_key(subscription_id)
            subscription_builder_cache = cls.get_subscription_builder(subscription_builder_id)
            if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id:
                raise ValueError(f"Subscription {subscription_id} expired or not found")

            subscription_builder_updater.update(subscription_builder_cache)

            redis_client.setex(
                cache_key, cls.__BUILDER_CACHE_EXPIRE_SECONDS__, subscription_builder_cache.model_dump_json()
            )
            return cls.builder_to_api_entity(controller=provider_controller, entity=subscription_builder_cache)

    @classmethod
    def update_and_verify_builder(
        cls,
        tenant_id: str,
        user_id: str,
        provider_id: TriggerProviderID,
        subscription_builder_id: str,
        subscription_builder_updater: SubscriptionBuilderUpdater,
    ) -> Mapping[str, Any]:
        """
        Atomically update and verify a subscription builder.
        This ensures the verification is done on the exact data that was just updated.
        """
        subscription_id = subscription_builder_id
        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
        if not provider_controller:
            raise ValueError(f"Provider {provider_id} not found")

        # Acquire lock for the entire update + verify operation
        with cls.acquire_builder_lock(subscription_id):
            cache_key = cls.encode_cache_key(subscription_id)
            subscription_builder_cache = cls.get_subscription_builder(subscription_builder_id)
            if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id:
                raise ValueError(f"Subscription {subscription_id} expired or not found")

            # Update
            subscription_builder_updater.update(subscription_builder_cache)
            redis_client.setex(
                cache_key, cls.__BUILDER_CACHE_EXPIRE_SECONDS__, subscription_builder_cache.model_dump_json()
            )

            # Verify (using the just-updated data)
            if subscription_builder_cache.credential_type == CredentialType.OAUTH2:
                return {"verified": bool(subscription_builder_cache.credentials)}

            if subscription_builder_cache.credential_type == CredentialType.API_KEY:
                credentials_to_validate = subscription_builder_cache.credentials
                try:
                    provider_controller.validate_credentials(user_id, credentials_to_validate)
                except ToolProviderCredentialValidationError as e:
                    raise ValueError(f"Invalid credentials: {e}") from e
                return {"verified": True}

            return {"verified": True}

    @classmethod
    def update_and_build_builder(
        cls,
        tenant_id: str,
        user_id: str,
        provider_id: TriggerProviderID,
        subscription_builder_id: str,
        subscription_builder_updater: SubscriptionBuilderUpdater,
    ) -> None:
        """
        Atomically update and build a subscription builder.
        This ensures the build uses the exact data that was just updated.
        """
        subscription_id = subscription_builder_id
        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
        if not provider_controller:
            raise ValueError(f"Provider {provider_id} not found")

        # Acquire lock for the entire update + build operation
        with cls.acquire_builder_lock(subscription_id):
            cache_key = cls.encode_cache_key(subscription_id)
            subscription_builder_cache = cls.get_subscription_builder(subscription_builder_id)
            if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id:
                raise ValueError(f"Subscription {subscription_id} expired or not found")

            # Update
            subscription_builder_updater.update(subscription_builder_cache)
            redis_client.setex(
                cache_key, cls.__BUILDER_CACHE_EXPIRE_SECONDS__, subscription_builder_cache.model_dump_json()
            )

            # Re-fetch to ensure we have the latest data
            subscription_builder = cls.get_subscription_builder(subscription_builder_id)
            if not subscription_builder:
                raise ValueError(f"Subscription builder {subscription_builder_id} not found")

            if not subscription_builder.name:
                raise ValueError("Subscription builder name is required")

            # Build
            credential_type = CredentialType.of(
                subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value
            )
            if credential_type == CredentialType.UNAUTHORIZED:
                # manually create
                TriggerProviderService.add_trigger_subscription(
                    subscription_id=subscription_builder.id,
                    tenant_id=tenant_id,
                    user_id=user_id,
                    name=subscription_builder.name,
                    provider_id=provider_id,
                    endpoint_id=subscription_builder.endpoint_id,
                    parameters=subscription_builder.parameters,
                    properties=subscription_builder.properties,
                    credential_expires_at=subscription_builder.credential_expires_at or -1,
                    expires_at=subscription_builder.expires_at,
                    credentials=subscription_builder.credentials,
                    credential_type=credential_type,
                )
            else:
                # automatically create
                subscription: Subscription = TriggerManager.subscribe_trigger(
                    tenant_id=tenant_id,
                    user_id=user_id,
                    provider_id=provider_id,
                    endpoint=generate_plugin_trigger_endpoint_url(subscription_builder.endpoint_id),
                    parameters=subscription_builder.parameters,
                    credentials=subscription_builder.credentials,
                    credential_type=credential_type,
                )

                TriggerProviderService.add_trigger_subscription(
                    subscription_id=subscription_builder.id,
                    tenant_id=tenant_id,
                    user_id=user_id,
                    name=subscription_builder.name,
                    provider_id=provider_id,
                    endpoint_id=subscription_builder.endpoint_id,
                    parameters=subscription_builder.parameters,
                    properties=subscription.properties,
                    credentials=subscription_builder.credentials,
                    credential_type=credential_type,
                    credential_expires_at=subscription_builder.credential_expires_at or -1,
                    expires_at=subscription_builder.expires_at,
                )

            # Delete the builder after successful subscription creation
            cache_key = cls.encode_cache_key(subscription_builder_id)
            redis_client.delete(cache_key)

    @classmethod
    def builder_to_api_entity(
        cls, controller: PluginTriggerProviderController, entity: SubscriptionBuilder
    ) -> SubscriptionBuilderApiEntity:
        credential_type = CredentialType.of(entity.credential_type or CredentialType.UNAUTHORIZED.value)
        return SubscriptionBuilderApiEntity(
            id=entity.id,
            name=entity.name or "",
            provider=entity.provider_id,
            endpoint=generate_plugin_trigger_endpoint_url(entity.endpoint_id),
            parameters=entity.parameters,
            properties=entity.properties,
            credential_type=credential_type,
            credentials=masked_credentials(
                schemas=controller.get_credentials_schema(credential_type),
                credentials=entity.credentials,
            )
            if controller.get_subscription_constructor()
            else {},
        )

    @classmethod
    def get_subscription_builder(cls, endpoint_id: str) -> SubscriptionBuilder | None:
        """
        Get a trigger subscription builder by its endpoint ID.
        """
        cache_key = cls.encode_cache_key(endpoint_id)
        subscription_cache = redis_client.get(cache_key)
        if subscription_cache:
            return SubscriptionBuilder.model_validate(json.loads(subscription_cache))

        return None

    @classmethod
    def append_log(cls, endpoint_id: str, request: Request, response: Response) -> None:
        """Append validation request log to Redis."""
        log = RequestLog(
            id=str(uuid.uuid4()),
            endpoint=endpoint_id,
            request={
                "method": request.method,
                "url": request.url,
                "headers": dict(request.headers),
                "data": request.get_data(as_text=True),
            },
            response={
                "status_code": response.status_code,
                "headers": dict(response.headers),
                "data": response.get_data(as_text=True),
            },
            created_at=datetime.now(),
        )

        key = f"trigger:subscription:builder:logs:{endpoint_id}"
        logs = json.loads(redis_client.get(key) or "[]")
        logs.append(log.model_dump(mode="json"))

        # Keep last N logs
        logs = logs[-cls.__VALIDATION_REQUEST_CACHE_COUNT__ :]
        redis_client.setex(key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_SECONDS__, json.dumps(logs, default=str))

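    # The slice above is a ring-buffer idiom: `logs[-N:]` keeps only the N most
    # recent entries, so the cached value stays bounded, e.g.
    #
    #     >>> [1, 2, 3, 4, 5][-3:]
    #     [3, 4, 5]
    #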
    @classmethod
    def list_logs(cls, endpoint_id: str) -> list[RequestLog]:
        """List request logs for validation endpoint."""
        key = f"trigger:subscription:builder:logs:{endpoint_id}"
        logs_json = redis_client.get(key)
        if not logs_json:
            return []
        return [RequestLog.model_validate(log) for log in json.loads(logs_json)]

    @classmethod
    def process_builder_validation_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
        """
        Process a temporary endpoint request.

        :param endpoint_id: The endpoint identifier
        :param request: The Flask request object
        :return: The Flask response object
        """
        # check if validation endpoint exists
        subscription_builder: SubscriptionBuilder | None = cls.get_subscription_builder(endpoint_id)
        if not subscription_builder:
            return None

        # response to validation endpoint
        controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
            tenant_id=subscription_builder.tenant_id, provider_id=TriggerProviderID(subscription_builder.provider_id)
        )
        try:
            dispatch_response: TriggerDispatchResponse = controller.dispatch(
                request=request,
                subscription=subscription_builder.to_subscription(),
                credentials={},
                credential_type=CredentialType.UNAUTHORIZED,
            )
            response: Response = dispatch_response.response
            # append the request log
            cls.append_log(
                endpoint_id=endpoint_id,
                request=request,
                response=response,
            )
            return response
        except Exception:
            logger.exception("Error during validation endpoint dispatch for endpoint_id=%s", endpoint_id)
            error_response = Response(status=500, response="An internal error has occurred.")
            cls.append_log(endpoint_id=endpoint_id, request=request, response=error_response)
            return error_response

    @classmethod
    def get_subscription_builder_by_id(cls, subscription_builder_id: str) -> SubscriptionBuilderApiEntity:
        """Get a trigger subscription builder API entity."""
        subscription_builder = cls.get_subscription_builder(subscription_builder_id)
        if not subscription_builder:
            raise ValueError(f"Subscription builder {subscription_builder_id} not found")
        return cls.builder_to_api_entity(
            controller=TriggerManager.get_trigger_provider(
                subscription_builder.tenant_id, TriggerProviderID(subscription_builder.provider_id)
            ),
            entity=subscription_builder,
        )
@@ -0,0 +1,70 @@
from sqlalchemy import and_, select
from sqlalchemy.orm import Session

from extensions.ext_database import db
from models.enums import AppTriggerStatus
from models.trigger import AppTrigger, WorkflowPluginTrigger


class TriggerSubscriptionOperatorService:
    @classmethod
    def get_subscriber_triggers(
        cls, tenant_id: str, subscription_id: str, event_name: str
    ) -> list[WorkflowPluginTrigger]:
        """
        Get WorkflowPluginTriggers for a subscription and event.

        Args:
            tenant_id: Tenant ID
            subscription_id: Subscription ID
            event_name: Event name
        """
        with Session(db.engine, expire_on_commit=False) as session:
            subscribers = session.scalars(
                select(WorkflowPluginTrigger)
                .join(
                    AppTrigger,
                    and_(
                        AppTrigger.tenant_id == WorkflowPluginTrigger.tenant_id,
                        AppTrigger.app_id == WorkflowPluginTrigger.app_id,
                        AppTrigger.node_id == WorkflowPluginTrigger.node_id,
                    ),
                )
                .where(
                    WorkflowPluginTrigger.tenant_id == tenant_id,
                    WorkflowPluginTrigger.subscription_id == subscription_id,
                    WorkflowPluginTrigger.event_name == event_name,
                    AppTrigger.status == AppTriggerStatus.ENABLED,
                )
            ).all()
            return list(subscribers)

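    # The composite join above pairs each WorkflowPluginTrigger with its
    # AppTrigger row on (tenant_id, app_id, node_id). Roughly equivalent SQL
    # (table names and enum value assumed, for illustration only):
    #
    #     SELECT wpt.*
    #     FROM workflow_plugin_triggers wpt
    #     JOIN app_triggers t
    #       ON t.tenant_id = wpt.tenant_id
    #      AND t.app_id = wpt.app_id
    #      AND t.node_id = wpt.node_id
    #     WHERE wpt.tenant_id = :tenant_id
    #       AND wpt.subscription_id = :subscription_id
    #       AND wpt.event_name = :event_name
    #       AND t.status = 'enabled'
    #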
    @classmethod
    def delete_plugin_trigger_by_subscription(
        cls,
        session: Session,
        tenant_id: str,
        subscription_id: str,
    ) -> None:
        """Delete a plugin trigger by tenant_id and subscription_id within an existing session.

        Args:
            session: Database session
            tenant_id: The tenant ID
            subscription_id: The subscription ID

        Note:
            Returns silently if no matching plugin trigger exists.
        """
        # Find plugin trigger using indexed columns
        plugin_trigger = session.scalar(
            select(WorkflowPluginTrigger).where(
                WorkflowPluginTrigger.tenant_id == tenant_id,
                WorkflowPluginTrigger.subscription_id == subscription_id,
            )
        )

        if not plugin_trigger:
            return

        session.delete(plugin_trigger)
902
dify/api/services/trigger/webhook_service.py
Normal file
@@ -0,0 +1,902 @@
import json
import logging
import mimetypes
import secrets
from collections.abc import Mapping
from typing import Any

import orjson
from flask import request
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import RequestEntityTooLarge

from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import FileTransferMethod
from core.tools.tool_file_manager import ToolFileManager
from core.variables.types import SegmentType
from core.workflow.enums import NodeType
from enums.quota_type import QuotaType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from factories import file_factory
from models.enums import AppTriggerStatus, AppTriggerType
from models.model import App
from models.trigger import AppTrigger, WorkflowWebhookTrigger
from models.workflow import Workflow
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.errors.app import QuotaExceededError
from services.trigger.app_trigger_service import AppTriggerService
from services.workflow.entities import WebhookTriggerData

logger = logging.getLogger(__name__)

class WebhookService:
    """Service for handling webhook operations."""

    __WEBHOOK_NODE_CACHE_KEY__ = "webhook_nodes"
    MAX_WEBHOOK_NODES_PER_WORKFLOW = 5  # Maximum allowed webhook nodes per workflow

    @staticmethod
    def _sanitize_key(key: str) -> str:
        """Normalize external keys (headers/params) to workflow-safe variables."""
        if not isinstance(key, str):
            return key
        return key.replace("-", "_")

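    # Doctest-style illustration of the normalization:
    #
    #     >>> WebhookService._sanitize_key("X-Request-Id")
    #     'X_Request_Id'
    #     >>> WebhookService._sanitize_key("Content-Type")
    #     'Content_Type'
    #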
    @classmethod
    def get_webhook_trigger_and_workflow(
        cls, webhook_id: str, is_debug: bool = False
    ) -> tuple[WorkflowWebhookTrigger, Workflow, Mapping[str, Any]]:
        """Get webhook trigger, workflow, and node configuration.

        Args:
            webhook_id: The webhook ID to look up
            is_debug: If True, use the draft workflow graph and skip the trigger enabled status check

        Returns:
            A tuple containing:
            - WorkflowWebhookTrigger: The webhook trigger object
            - Workflow: The associated workflow object
            - Mapping[str, Any]: The node configuration data

        Raises:
            ValueError: If webhook not found, app trigger not found, trigger disabled, or workflow not found
        """
        with Session(db.engine) as session:
            # Get webhook trigger
            webhook_trigger = (
                session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.webhook_id == webhook_id).first()
            )
            if not webhook_trigger:
                raise ValueError(f"Webhook not found: {webhook_id}")

            if is_debug:
                workflow = (
                    session.query(Workflow)
                    .filter(
                        Workflow.app_id == webhook_trigger.app_id,
                        Workflow.version == Workflow.VERSION_DRAFT,
                    )
                    .order_by(Workflow.created_at.desc())
                    .first()
                )
            else:
                # Check if the corresponding AppTrigger exists
                app_trigger = (
                    session.query(AppTrigger)
                    .filter(
                        AppTrigger.app_id == webhook_trigger.app_id,
                        AppTrigger.node_id == webhook_trigger.node_id,
                        AppTrigger.trigger_type == AppTriggerType.TRIGGER_WEBHOOK,
                    )
                    .first()
                )

                if not app_trigger:
                    raise ValueError(f"App trigger not found for webhook {webhook_id}")

                # Only check enabled status if not in debug mode
                if app_trigger.status == AppTriggerStatus.RATE_LIMITED:
                    raise ValueError(
                        f"Webhook trigger is rate limited for webhook {webhook_id}, please upgrade your plan."
                    )

                if app_trigger.status != AppTriggerStatus.ENABLED:
                    raise ValueError(f"Webhook trigger is disabled for webhook {webhook_id}")

                # Get workflow
                workflow = (
                    session.query(Workflow)
                    .filter(
                        Workflow.app_id == webhook_trigger.app_id,
                        Workflow.version != Workflow.VERSION_DRAFT,
                    )
                    .order_by(Workflow.created_at.desc())
                    .first()
                )
            if not workflow:
                raise ValueError(f"Workflow not found for app {webhook_trigger.app_id}")

            node_config = workflow.get_node_config_by_id(webhook_trigger.node_id)

            return webhook_trigger, workflow, node_config

    @classmethod
    def extract_and_validate_webhook_data(
        cls, webhook_trigger: WorkflowWebhookTrigger, node_config: Mapping[str, Any]
    ) -> dict[str, Any]:
        """Extract and validate webhook data in a single unified process.

        Args:
            webhook_trigger: The webhook trigger object containing metadata
            node_config: The node configuration containing validation rules

        Returns:
            dict[str, Any]: Processed and validated webhook data with correct types

        Raises:
            ValueError: If validation fails (HTTP method mismatch, missing required fields, type errors)
        """
        # Extract raw data first
        raw_data = cls.extract_webhook_data(webhook_trigger)

        # Validate HTTP metadata (method, content-type)
        node_data = node_config.get("data", {})
        validation_result = cls._validate_http_metadata(raw_data, node_data)
        if not validation_result["valid"]:
            raise ValueError(validation_result["error"])

        # Process and validate data according to configuration
        processed_data = cls._process_and_validate_data(raw_data, node_data)

        return processed_data

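    # End-to-end sketch of how a webhook controller would compose these
    # methods (hypothetical caller code, simplified):
    #
    #     trigger, workflow, node_config = WebhookService.get_webhook_trigger_and_workflow(webhook_id)
    #     data = WebhookService.extract_and_validate_webhook_data(trigger, node_config)
    #     WebhookService.trigger_workflow_execution(trigger, data, workflow)
    #     body, status = WebhookService.generate_webhook_response(node_config)
    #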
    @classmethod
    def extract_webhook_data(cls, webhook_trigger: WorkflowWebhookTrigger) -> dict[str, Any]:
        """Extract raw data from incoming webhook request without type conversion.

        Args:
            webhook_trigger: The webhook trigger object for file processing context

        Returns:
            dict[str, Any]: Raw webhook data containing:
            - method: HTTP method
            - headers: Request headers
            - query_params: Query parameters as strings
            - body: Request body (varies by content type; JSON parsing errors raise ValueError)
            - files: Uploaded files (if any)
        """
        cls._validate_content_length()

        data = {
            "method": request.method,
            "headers": dict(request.headers),
            "query_params": dict(request.args),
            "body": {},
            "files": {},
        }

        # Extract and normalize content type
        content_type = cls._extract_content_type(dict(request.headers))

        # Route to appropriate extractor based on content type
        extractors = {
            "application/json": cls._extract_json_body,
            "application/x-www-form-urlencoded": cls._extract_form_body,
            "multipart/form-data": lambda: cls._extract_multipart_body(webhook_trigger),
            "application/octet-stream": lambda: cls._extract_octet_stream_body(webhook_trigger),
            "text/plain": cls._extract_text_body,
        }

        extractor = extractors.get(content_type)
        if not extractor:
            # Default to text/plain for unknown content types
            logger.warning("Unknown Content-Type: %s, treating as text/plain", content_type)
            extractor = cls._extract_text_body

        # Extract body and files
        body_data, files_data = extractor()
        data["body"] = body_data
        data["files"] = files_data

        return data

    @classmethod
    def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]:
        """Process and validate webhook data according to node configuration.

        Args:
            raw_data: Raw webhook data from extraction
            node_data: Node configuration containing validation and type rules

        Returns:
            dict[str, Any]: Processed data with validated types

        Raises:
            ValueError: If validation fails or required fields are missing
        """
        result = raw_data.copy()

        # Validate and process headers
        cls._validate_required_headers(raw_data["headers"], node_data.get("headers", []))

        # Process query parameters with type conversion and validation
        result["query_params"] = cls._process_parameters(
            raw_data["query_params"], node_data.get("params", []), is_form_data=True
        )

        # Process body parameters based on content type
        configured_content_type = node_data.get("content_type", "application/json").lower()
        result["body"] = cls._process_body_parameters(
            raw_data["body"], node_data.get("body", []), configured_content_type
        )

        return result

    @classmethod
    def _validate_content_length(cls) -> None:
        """Validate request content length against maximum allowed size."""
        content_length = request.content_length
        if content_length and content_length > dify_config.WEBHOOK_REQUEST_BODY_MAX_SIZE:
            raise RequestEntityTooLarge(
                f"Webhook request too large: {content_length} bytes exceeds maximum allowed size "
                f"of {dify_config.WEBHOOK_REQUEST_BODY_MAX_SIZE} bytes"
            )

    @classmethod
    def _extract_json_body(cls) -> tuple[dict[str, Any], dict[str, Any]]:
        """Extract JSON body from request.

        Returns:
            tuple: (body_data, files_data) where:
            - body_data: Parsed JSON content
            - files_data: Empty dict (JSON requests don't contain files)

        Raises:
            ValueError: If JSON parsing fails
        """
        raw_body = request.get_data(cache=True)
        if not raw_body or raw_body.strip() == b"":
            return {}, {}

        try:
            body = orjson.loads(raw_body)
        except orjson.JSONDecodeError as exc:
            logger.warning("Failed to parse JSON body: %s", exc)
            raise ValueError(f"Invalid JSON body: {exc}") from exc
        return body, {}

    @classmethod
    def _extract_form_body(cls) -> tuple[dict[str, Any], dict[str, Any]]:
        """Extract form-urlencoded body from request.

        Returns:
            tuple: (body_data, files_data) where:
            - body_data: Form data as key-value pairs
            - files_data: Empty dict (form-urlencoded requests don't contain files)
        """
        return dict(request.form), {}

    @classmethod
    def _extract_multipart_body(cls, webhook_trigger: WorkflowWebhookTrigger) -> tuple[dict[str, Any], dict[str, Any]]:
        """Extract multipart/form-data body and files from request.

        Args:
            webhook_trigger: Webhook trigger for file processing context

        Returns:
            tuple: (body_data, files_data) where:
            - body_data: Form data as key-value pairs
            - files_data: Processed file objects indexed by field name
        """
        body = dict(request.form)
        files = cls._process_file_uploads(request.files, webhook_trigger) if request.files else {}
        return body, files

    @classmethod
    def _extract_octet_stream_body(
        cls, webhook_trigger: WorkflowWebhookTrigger
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Extract binary data as file from request.

        Args:
            webhook_trigger: Webhook trigger for file processing context

        Returns:
            tuple: (body_data, files_data) where:
            - body_data: Dict with 'raw' key containing file object or None
            - files_data: Empty dict
        """
        try:
            file_content = request.get_data()
            if file_content:
                file_obj = cls._create_file_from_binary(file_content, "application/octet-stream", webhook_trigger)
                return {"raw": file_obj.to_dict()}, {}
            else:
                return {"raw": None}, {}
        except Exception:
            logger.exception("Failed to process octet-stream data")
            return {"raw": None}, {}

    @classmethod
    def _extract_text_body(cls) -> tuple[dict[str, Any], dict[str, Any]]:
        """Extract text/plain body from request.

        Returns:
            tuple: (body_data, files_data) where:
            - body_data: Dict with 'raw' key containing text content
            - files_data: Empty dict (text requests don't contain files)
        """
        try:
            body = {"raw": request.get_data(as_text=True)}
        except Exception:
            logger.warning("Failed to extract text body")
            body = {"raw": ""}
        return body, {}

    @classmethod
    def _process_file_uploads(
        cls, files: Mapping[str, FileStorage], webhook_trigger: WorkflowWebhookTrigger
    ) -> dict[str, Any]:
        """Process file uploads using ToolFileManager.

        Args:
            files: Flask request files object containing uploaded files
            webhook_trigger: Webhook trigger for tenant and user context

        Returns:
            dict[str, Any]: Processed file objects indexed by field name
        """
        processed_files = {}

        for name, file in files.items():
            if file and file.filename:
                try:
                    file_content = file.read()
                    mimetype = file.content_type or mimetypes.guess_type(file.filename)[0] or "application/octet-stream"
                    file_obj = cls._create_file_from_binary(file_content, mimetype, webhook_trigger)
                    processed_files[name] = file_obj.to_dict()
                except Exception:
                    logger.exception("Failed to process file upload '%s'", name)
                    # Continue processing other files

        return processed_files

    @classmethod
    def _create_file_from_binary(
        cls, file_content: bytes, mimetype: str, webhook_trigger: WorkflowWebhookTrigger
    ) -> Any:
        """Create a file object from binary content using ToolFileManager.

        Args:
            file_content: The binary content of the file
            mimetype: The MIME type of the file
            webhook_trigger: Webhook trigger for tenant and user context

        Returns:
            Any: A file object built from the binary content
        """
        tool_file_manager = ToolFileManager()

        # Create file using ToolFileManager
        tool_file = tool_file_manager.create_file_by_raw(
            user_id=webhook_trigger.created_by,
            tenant_id=webhook_trigger.tenant_id,
            conversation_id=None,
            file_binary=file_content,
            mimetype=mimetype,
        )

        # Build File object
        mapping = {
            "tool_file_id": tool_file.id,
            "transfer_method": FileTransferMethod.TOOL_FILE.value,
        }
        return file_factory.build_from_mapping(
            mapping=mapping,
            tenant_id=webhook_trigger.tenant_id,
        )

    @classmethod
    def _process_parameters(
        cls, raw_params: dict[str, str], param_configs: list, is_form_data: bool = False
    ) -> dict[str, Any]:
        """Process parameters with unified validation and type conversion.

        Args:
            raw_params: Raw parameter values as strings
            param_configs: List of parameter configuration dictionaries
            is_form_data: Whether the parameters are from form data (requiring string conversion)

        Returns:
            dict[str, Any]: Processed parameters with validated types

        Raises:
            ValueError: If required parameters are missing or validation fails
        """
        processed = {}
        configured_params = {config.get("name", ""): config for config in param_configs}

        # Process configured parameters
        for param_config in param_configs:
            name = param_config.get("name", "")
            param_type = param_config.get("type", SegmentType.STRING)
            required = param_config.get("required", False)

            # Check required parameters
            if required and name not in raw_params:
                raise ValueError(f"Required parameter missing: {name}")

            if name in raw_params:
                raw_value = raw_params[name]
                processed[name] = cls._validate_and_convert_value(name, raw_value, param_type, is_form_data)

        # Include unconfigured parameters as strings
        for name, value in raw_params.items():
            if name not in configured_params:
                processed[name] = value

        return processed

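    # Illustrative call with a hypothetical config: the configured parameter is
    # type-converted, the unconfigured one passes through as a string.
    #
    #     >>> WebhookService._process_parameters(
    #     ...     {"limit": "10", "debug": "x"},
    #     ...     [{"name": "limit", "type": SegmentType.NUMBER, "required": True}],
    #     ...     is_form_data=True,
    #     ... )
    #     {'limit': 10, 'debug': 'x'}
    #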
    @classmethod
    def _process_body_parameters(
        cls, raw_body: dict[str, Any], body_configs: list, content_type: str
    ) -> dict[str, Any]:
        """Process body parameters based on content type and configuration.

        Args:
            raw_body: Raw body data from request
            body_configs: List of body parameter configuration dictionaries
            content_type: The request content type

        Returns:
            dict[str, Any]: Processed body parameters with validated types

        Raises:
            ValueError: If required body parameters are missing or validation fails
        """
        if content_type in ["text/plain", "application/octet-stream"]:
            # For text/plain and octet-stream, validate required content exists
            if body_configs and any(config.get("required", False) for config in body_configs):
                raw_content = raw_body.get("raw")
                if not raw_content:
                    raise ValueError(f"Required body content missing for {content_type} request")
            return raw_body

        # For structured data (JSON, form-data, etc.)
        processed = {}
        configured_params = {config.get("name", ""): config for config in body_configs}

        for body_config in body_configs:
            name = body_config.get("name", "")
            param_type = body_config.get("type", SegmentType.STRING)
            required = body_config.get("required", False)

            # Handle file parameters for multipart data
            if param_type == SegmentType.FILE and content_type == "multipart/form-data":
                # File validation is handled separately in extract phase
                continue

            # Check required parameters
            if required and name not in raw_body:
                raise ValueError(f"Required body parameter missing: {name}")

            if name in raw_body:
                raw_value = raw_body[name]
                is_form_data = content_type in ["application/x-www-form-urlencoded", "multipart/form-data"]
                processed[name] = cls._validate_and_convert_value(name, raw_value, param_type, is_form_data)

        # Include unconfigured parameters
        for name, value in raw_body.items():
            if name not in configured_params:
                processed[name] = value

        return processed

    @classmethod
    def _validate_and_convert_value(cls, param_name: str, value: Any, param_type: str, is_form_data: bool) -> Any:
        """Unified validation and type conversion for parameter values.

        Args:
            param_name: Name of the parameter for error reporting
            value: The value to validate and convert
            param_type: The expected parameter type (SegmentType)
            is_form_data: Whether the value is from form data (requiring string conversion)

        Returns:
            Any: The validated and converted value

        Raises:
            ValueError: If validation or conversion fails
        """
        try:
            if is_form_data:
                # Form data comes as strings and needs conversion
                return cls._convert_form_value(param_name, value, param_type)
            else:
                # JSON data should already be in correct types, just validate
                return cls._validate_json_value(param_name, value, param_type)
        except Exception as e:
            raise ValueError(f"Parameter '{param_name}' validation failed: {e}") from e

    @classmethod
    def _convert_form_value(cls, param_name: str, value: str, param_type: str) -> Any:
        """Convert form data string values to specified types.

        Args:
            param_name: Name of the parameter for error reporting
            value: The string value to convert
            param_type: The target type to convert to (SegmentType)

        Returns:
            Any: The converted value in the appropriate type

        Raises:
            ValueError: If the value cannot be converted to the specified type
        """
        if param_type == SegmentType.STRING:
            return value
        elif param_type == SegmentType.NUMBER:
            if not cls._can_convert_to_number(value):
                raise ValueError(f"Cannot convert '{value}' to number")
            numeric_value = float(value)
            return int(numeric_value) if numeric_value.is_integer() else numeric_value
        elif param_type == SegmentType.BOOLEAN:
            lower_value = value.lower()
            bool_map = {"true": True, "false": False, "1": True, "0": False, "yes": True, "no": False}
            if lower_value not in bool_map:
                raise ValueError(f"Cannot convert '{value}' to boolean")
            return bool_map[lower_value]
        else:
            raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'")

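    # Examples of the conversion rules (doctest-style, illustrative):
    #
    #     >>> WebhookService._convert_form_value("amount", "42", SegmentType.NUMBER)
    #     42
    #     >>> WebhookService._convert_form_value("price", "3.5", SegmentType.NUMBER)
    #     3.5
    #     >>> WebhookService._convert_form_value("flag", "YES", SegmentType.BOOLEAN)
    #     True
    #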
    @classmethod
    def _validate_json_value(cls, param_name: str, value: Any, param_type: str) -> Any:
        """Validate JSON values against expected types.

        Args:
            param_name: Name of the parameter for error reporting
            value: The value to validate
            param_type: The expected parameter type (SegmentType)

        Returns:
            Any: The validated value (unchanged if valid)

        Raises:
            ValueError: If the value type doesn't match the expected type
        """
        type_validators = {
            SegmentType.STRING: (lambda v: isinstance(v, str), "string"),
            SegmentType.NUMBER: (lambda v: isinstance(v, (int, float)), "number"),
            SegmentType.BOOLEAN: (lambda v: isinstance(v, bool), "boolean"),
            SegmentType.OBJECT: (lambda v: isinstance(v, dict), "object"),
            SegmentType.ARRAY_STRING: (
                lambda v: isinstance(v, list) and all(isinstance(item, str) for item in v),
                "array of strings",
            ),
            SegmentType.ARRAY_NUMBER: (
                lambda v: isinstance(v, list) and all(isinstance(item, (int, float)) for item in v),
                "array of numbers",
            ),
            SegmentType.ARRAY_BOOLEAN: (
                lambda v: isinstance(v, list) and all(isinstance(item, bool) for item in v),
                "array of booleans",
            ),
            SegmentType.ARRAY_OBJECT: (
                lambda v: isinstance(v, list) and all(isinstance(item, dict) for item in v),
                "array of objects",
            ),
        }

        validator_info = type_validators.get(SegmentType(param_type))
        if not validator_info:
            logger.warning("Unknown parameter type: %s for parameter %s", param_type, param_name)
            return value

        validator, expected_type = validator_info
        if not validator(value):
            actual_type = type(value).__name__
            raise ValueError(f"Expected {expected_type}, got {actual_type}")

        return value

    @classmethod
    def _validate_required_headers(cls, headers: dict[str, Any], header_configs: list) -> None:
        """Validate required headers are present.

        Args:
            headers: Request headers dictionary
            header_configs: List of header configuration dictionaries

        Raises:
            ValueError: If required headers are missing
        """
        headers_lower = {k.lower(): v for k, v in headers.items()}
        headers_sanitized = {cls._sanitize_key(k).lower(): v for k, v in headers.items()}
        for header_config in header_configs:
            if header_config.get("required", False):
                header_name = header_config.get("name", "")
                sanitized_name = cls._sanitize_key(header_name).lower()
                if header_name.lower() not in headers_lower and sanitized_name not in headers_sanitized:
                    raise ValueError(f"Required header missing: {header_name}")

    @classmethod
    def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]:
        """Validate HTTP method and content-type.

        Args:
            webhook_data: Extracted webhook data containing method and headers
            node_data: Node configuration containing expected method and content-type

        Returns:
            dict[str, Any]: Validation result with 'valid' key and optional 'error' key
        """
        # Validate HTTP method
        configured_method = node_data.get("method", "get").upper()
        request_method = webhook_data["method"].upper()
        if configured_method != request_method:
            return cls._validation_error(f"HTTP method mismatch. Expected {configured_method}, got {request_method}")

        # Validate content-type
        configured_content_type = node_data.get("content_type", "application/json").lower()
        request_content_type = cls._extract_content_type(webhook_data["headers"])

        if configured_content_type != request_content_type:
            return cls._validation_error(
                f"Content-type mismatch. Expected {configured_content_type}, got {request_content_type}"
            )

        return {"valid": True}

    @classmethod
    def _extract_content_type(cls, headers: dict[str, Any]) -> str:
        """Extract and normalize content-type from headers.

        Args:
            headers: Request headers dictionary

        Returns:
            str: Normalized content-type (main type without parameters)
        """
        content_type = headers.get("Content-Type", "").lower()
        if not content_type:
            content_type = headers.get("content-type", "application/json").lower()
        # Extract the main content type (ignore parameters like boundary)
        return content_type.split(";")[0].strip()

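    # Parameters such as charset or boundary are stripped, and a missing header
    # falls back to application/json (doctest-style, illustrative):
    #
    #     >>> WebhookService._extract_content_type({"Content-Type": "multipart/form-data; boundary=xyz"})
    #     'multipart/form-data'
    #     >>> WebhookService._extract_content_type({})
    #     'application/json'
    #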
    @classmethod
    def _validation_error(cls, error_message: str) -> dict[str, Any]:
        """Create a standard validation error response.

        Args:
            error_message: The error message to include

        Returns:
            dict[str, Any]: Validation error response with 'valid' and 'error' keys
        """
        return {"valid": False, "error": error_message}

    @classmethod
    def _can_convert_to_number(cls, value: str) -> bool:
        """Check if a string can be converted to a number."""
        try:
            float(value)
            return True
        except ValueError:
            return False

    @classmethod
    def build_workflow_inputs(cls, webhook_data: dict[str, Any]) -> dict[str, Any]:
        """Construct workflow inputs payload from webhook data.

        Args:
            webhook_data: Processed webhook data containing headers, query params, and body

        Returns:
            dict[str, Any]: Workflow inputs formatted for execution
        """
        return {
            "webhook_data": webhook_data,
            "webhook_headers": webhook_data.get("headers", {}),
            "webhook_query_params": webhook_data.get("query_params", {}),
            "webhook_body": webhook_data.get("body", {}),
        }

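    # Resulting shape for a simple JSON webhook (illustrative values):
    #
    #     {
    #         "webhook_data": {...the full processed payload...},
    #         "webhook_headers": {"Content-Type": "application/json"},
    #         "webhook_query_params": {"page": "1"},
    #         "webhook_body": {"event": "created"},
    #     }
    #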
    @classmethod
    def trigger_workflow_execution(
        cls, webhook_trigger: WorkflowWebhookTrigger, webhook_data: dict[str, Any], workflow: Workflow
    ) -> None:
        """Trigger workflow execution via AsyncWorkflowService.

        Args:
            webhook_trigger: The webhook trigger object
            webhook_data: Processed webhook data for workflow inputs
            workflow: The workflow to execute

        Raises:
            ValueError: If tenant owner is not found
            Exception: If workflow execution fails
        """
        try:
            with Session(db.engine) as session:
                # Prepare inputs for the webhook node
                # The webhook node expects webhook_data in the inputs
                workflow_inputs = cls.build_workflow_inputs(webhook_data)

                # Create trigger data
                trigger_data = WebhookTriggerData(
                    app_id=webhook_trigger.app_id,
                    workflow_id=workflow.id,
                    root_node_id=webhook_trigger.node_id,  # Start from the webhook node
                    inputs=workflow_inputs,
                    tenant_id=webhook_trigger.tenant_id,
                )

                end_user = EndUserService.get_or_create_end_user_by_type(
                    type=InvokeFrom.TRIGGER,
                    tenant_id=webhook_trigger.tenant_id,
                    app_id=webhook_trigger.app_id,
                    user_id=None,
                )

                # consume quota before triggering workflow execution
                try:
                    QuotaType.TRIGGER.consume(webhook_trigger.tenant_id)
                except QuotaExceededError:
                    AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
                    logger.info(
                        "Tenant %s rate limited, skipping webhook trigger %s",
                        webhook_trigger.tenant_id,
                        webhook_trigger.webhook_id,
                    )
                    raise

                # Trigger workflow execution asynchronously
                AsyncWorkflowService.trigger_workflow_async(
                    session,
                    end_user,
                    trigger_data,
                )

        except Exception:
            logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id)
            raise

    @classmethod
    def generate_webhook_response(cls, node_config: Mapping[str, Any]) -> tuple[dict[str, Any], int]:
        """Generate HTTP response based on node configuration.

        Args:
            node_config: Node configuration containing response settings

        Returns:
            tuple[dict[str, Any], int]: Response data and HTTP status code
        """
        node_data = node_config.get("data", {})

        # Get configured status code and response body
        status_code = node_data.get("status_code", 200)
        response_body = node_data.get("response_body", "")

        # Parse response body as JSON if it's valid JSON, otherwise return as text
        try:
            if response_body:
                try:
                    response_data = (
                        json.loads(response_body)
                        if response_body.strip().startswith(("{", "["))
                        else {"message": response_body}
                    )
                except json.JSONDecodeError:
                    response_data = {"message": response_body}
            else:
                response_data = {"status": "success", "message": "Webhook processed successfully"}
        except Exception:
            response_data = {"message": response_body or "Webhook processed successfully"}

        return response_data, status_code

    @classmethod
    def sync_webhook_relationships(cls, app: App, workflow: Workflow):
        """
        Sync webhook relationships in DB.

        1. Check if the workflow has any webhook trigger nodes
        2. Fetch the nodes from DB, see if there were any webhook records already
        3. Diff the nodes and the webhook records, create/update/delete the webhook records as needed

        Approach:
            Frequent DB operations may cause performance issues, so Redis is used as a cache instead.
            If any record exists, cache it.

        Limits:
            - Maximum 5 webhook nodes per workflow
        """

        class Cache(BaseModel):
            """
            Cache model for webhook nodes
            """

            record_id: str
            node_id: str
            webhook_id: str

        nodes_id_in_graph = [node_id for node_id, _ in workflow.walk_nodes(NodeType.TRIGGER_WEBHOOK)]

        # Check webhook node limit
        if len(nodes_id_in_graph) > cls.MAX_WEBHOOK_NODES_PER_WORKFLOW:
            raise ValueError(
                f"Workflow exceeds maximum webhook node limit. "
                f"Found {len(nodes_id_in_graph)} webhook nodes, maximum allowed is {cls.MAX_WEBHOOK_NODES_PER_WORKFLOW}"
            )

        not_found_in_cache: list[str] = []
        for node_id in nodes_id_in_graph:
            # firstly check if the node exists in cache
            if not redis_client.get(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}"):
                not_found_in_cache.append(node_id)
                continue

        with Session(db.engine) as session:
            try:
                # lock against concurrent webhook trigger creation
                # (acquire() blocks until the lock is held or times out)
                redis_client.lock(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10).acquire()
                # fetch the non-cached nodes from DB
                all_records = session.scalars(
                    select(WorkflowWebhookTrigger).where(
                        WorkflowWebhookTrigger.app_id == app.id,
                        WorkflowWebhookTrigger.tenant_id == app.tenant_id,
                    )
                ).all()

                nodes_id_in_db = {node.node_id: node for node in all_records}

                # get the nodes not found both in cache and DB
                nodes_not_found = [node_id for node_id in not_found_in_cache if node_id not in nodes_id_in_db]

                # create new webhook records
                for node_id in nodes_not_found:
                    webhook_record = WorkflowWebhookTrigger(
                        app_id=app.id,
                        tenant_id=app.tenant_id,
                        node_id=node_id,
                        webhook_id=cls.generate_webhook_id(),
                        created_by=app.created_by,
                    )
                    session.add(webhook_record)
                    session.flush()
                    cache = Cache(record_id=webhook_record.id, node_id=node_id, webhook_id=webhook_record.webhook_id)
                    redis_client.set(
                        f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}", cache.model_dump_json(), ex=60 * 60
                    )
                session.commit()

                # delete the nodes not found in the graph
                for node_id in nodes_id_in_db:
                    if node_id not in nodes_id_in_graph:
                        session.delete(nodes_id_in_db[node_id])
                        redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}")
                session.commit()
            except Exception:
                logger.exception("Failed to sync webhook relationships for app %s", app.id)
                raise
            finally:
                redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock")

    @classmethod
    def generate_webhook_id(cls) -> str:
        """
        Generate unique 24-character webhook ID

        Deduplication is not needed, DB already has unique constraint on webhook_id.
        """
        # 18 random bytes encode to exactly 24 base64url characters
        return secrets.token_urlsafe(18)[:24]
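    # The [:24] slice above is a length guard rather than a truncation, since
    # 18 bytes * 4/3 = 24 base64url characters with no padding:
    #
    #     >>> import secrets
    #     >>> len(secrets.token_urlsafe(18))
    #     24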