commit fab8c13cb3
parent 32fee2b8ab
2025-12-01 17:21:38 +08:00
7511 changed files with 996300 additions and 0 deletions


@@ -0,0 +1,3 @@
from . import errors
__all__ = ["errors"]

File diff suppressed because it is too large


@@ -0,0 +1,97 @@
import copy
from core.prompt.prompt_templates.advanced_prompt_templates import (
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_CONTEXT,
CHAT_APP_CHAT_PROMPT_CONFIG,
CHAT_APP_COMPLETION_PROMPT_CONFIG,
COMPLETION_APP_CHAT_PROMPT_CONFIG,
COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
CONTEXT,
)
from models.model import AppMode
class AdvancedPromptTemplateService:
@classmethod
def get_prompt(cls, args: dict):
app_mode = args["app_mode"]
model_mode = args["model_mode"]
model_name = args["model_name"]
has_context = args["has_context"]
if "baichuan" in model_name.lower():
return cls.get_baichuan_prompt(app_mode, model_mode, has_context)
else:
return cls.get_common_prompt(app_mode, model_mode, has_context)
@classmethod
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str):
context_prompt = copy.deepcopy(CONTEXT)
if app_mode == AppMode.CHAT:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
)
elif model_mode == "chat":
return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
elif app_mode == AppMode.COMPLETION:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
)
elif model_mode == "chat":
return cls.get_chat_prompt(
copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt
)
# default: return an empty dict
return {}
@classmethod
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str):
if has_context == "true":
prompt_template["completion_prompt_config"]["prompt"]["text"] = (
context + prompt_template["completion_prompt_config"]["prompt"]["text"]
)
return prompt_template
@classmethod
def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str):
if has_context == "true":
prompt_template["chat_prompt_config"]["prompt"][0]["text"] = (
context + prompt_template["chat_prompt_config"]["prompt"][0]["text"]
)
return prompt_template
@classmethod
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str):
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
if app_mode == AppMode.CHAT:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt
)
elif model_mode == "chat":
return cls.get_chat_prompt(
copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
)
elif app_mode == AppMode.COMPLETION:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG),
has_context,
baichuan_context_prompt,
)
elif model_mode == "chat":
return cls.get_chat_prompt(
copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
)
# default: return an empty dict
return {}
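
For reference, a minimal usage sketch of the service above. The import path is an assumption (the file name is not shown in this diff), "gpt-4" is an arbitrary model name chosen to exercise the non-Baichuan branch, and the argument keys mirror what get_prompt() reads.

# Illustrative sketch only; the module path is assumed, since the diff does not show the file name.
# get_prompt() is a pure in-memory helper, so no database or app context is needed.
from services.advanced_prompt_template_service import AdvancedPromptTemplateService  # assumed path

prompt_config = AdvancedPromptTemplateService.get_prompt(
    {
        "app_mode": "chat",         # compared against AppMode.CHAT
        "model_mode": "completion",
        "model_name": "gpt-4",      # any non-Baichuan name selects get_common_prompt()
        "has_context": "true",      # a string flag, not a bool
    }
)
# When has_context == "true", the context template is prepended to the completion prompt text.
print(prompt_config["completion_prompt_config"]["prompt"]["text"])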


@@ -0,0 +1,175 @@
import threading
from typing import Any
import pytz
import contexts
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
from core.plugin.impl.agent import PluginAgentClient
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from libs.login import current_user
from models import Account
from models.model import App, Conversation, EndUser, Message, MessageAgentThought
class AgentService:
@classmethod
def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str):
"""
Service to get agent logs
"""
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
conversation: Conversation | None = (
db.session.query(Conversation)
.where(
Conversation.id == conversation_id,
Conversation.app_id == app_model.id,
)
.first()
)
if not conversation:
raise ValueError(f"Conversation not found: {conversation_id}")
message: Message | None = (
db.session.query(Message)
.where(
Message.id == message_id,
Message.conversation_id == conversation_id,
)
.first()
)
if not message:
raise ValueError(f"Message not found: {message_id}")
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if conversation.from_end_user_id:
# only select name field
executor = (
db.session.query(EndUser, EndUser.name).where(EndUser.id == conversation.from_end_user_id).first()
)
else:
executor = db.session.query(Account, Account.name).where(Account.id == conversation.from_account_id).first()
if executor:
executor = executor.name
else:
executor = "Unknown"
assert isinstance(current_user, Account)
assert current_user.timezone is not None
timezone = pytz.timezone(current_user.timezone)
app_model_config = app_model.app_model_config
if not app_model_config:
raise ValueError("App model config not found")
result: dict[str, Any] = {
"meta": {
"status": "success",
"executor": executor,
"start_time": message.created_at.astimezone(timezone).isoformat(),
"elapsed_time": message.provider_response_latency,
"total_tokens": message.answer_tokens + message.message_tokens,
"agent_mode": app_model_config.agent_mode_dict.get("strategy", "react"),
"iterations": len(agent_thoughts),
},
"iterations": [],
"files": message.message_files,
}
agent_config = AgentConfigManager.convert(app_model_config.to_dict())
if not agent_config:
raise ValueError("Agent config not found")
agent_tools = agent_config.tools or []
def find_agent_tool(tool_name: str):
for agent_tool in agent_tools:
if agent_tool.tool_name == tool_name:
return agent_tool
for agent_thought in agent_thoughts:
tools = agent_thought.tools
tool_labels = agent_thought.tool_labels
tool_meta = agent_thought.tool_meta
tool_inputs = agent_thought.tool_inputs_dict
tool_outputs = agent_thought.tool_outputs_dict or {}
tool_calls = []
for tool in tools:
tool_name = tool
tool_label = tool_labels.get(tool_name, tool_name)
tool_input = tool_inputs.get(tool_name, {})
tool_output = tool_outputs.get(tool_name, {})
tool_meta_data = tool_meta.get(tool_name, {})
tool_config = tool_meta_data.get("tool_config", {})
if tool_config.get("tool_provider_type", "") != "dataset-retrieval":
tool_icon = ToolManager.get_tool_icon(
tenant_id=app_model.tenant_id,
provider_type=tool_config.get("tool_provider_type", ""),
provider_id=tool_config.get("tool_provider", ""),
)
if not tool_icon:
tool_entity = find_agent_tool(tool_name)
if tool_entity:
tool_icon = ToolManager.get_tool_icon(
tenant_id=app_model.tenant_id,
provider_type=tool_entity.provider_type,
provider_id=tool_entity.provider_id,
)
else:
tool_icon = ""
tool_calls.append(
{
"status": "success" if not tool_meta_data.get("error") else "error",
"error": tool_meta_data.get("error"),
"time_cost": tool_meta_data.get("time_cost", 0),
"tool_name": tool_name,
"tool_label": tool_label,
"tool_input": tool_input,
"tool_output": tool_output,
"tool_parameters": tool_meta_data.get("tool_parameters", {}),
"tool_icon": tool_icon,
}
)
result["iterations"].append(
{
"tokens": agent_thought.tokens,
"tool_calls": tool_calls,
"tool_raw": {
"inputs": agent_thought.tool_input,
"outputs": agent_thought.observation,
},
"thought": agent_thought.thought,
"created_at": agent_thought.created_at.isoformat(),
"files": agent_thought.files,
}
)
return result
@classmethod
def list_agent_providers(cls, user_id: str, tenant_id: str):
"""
List agent providers
"""
manager = PluginAgentClient()
return manager.fetch_agent_strategy_providers(tenant_id)
@classmethod
def get_agent_provider(cls, user_id: str, tenant_id: str, provider_name: str):
"""
Get agent provider
"""
manager = PluginAgentClient()
try:
return manager.fetch_agent_strategy_provider(tenant_id, provider_name)
except PluginDaemonClientSideError as e:
raise ValueError(str(e)) from e
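
A small, self-contained consumer sketch for the dict that get_agent_logs() returns; the key names come from the `result` structure built above, while the helper itself is purely illustrative.

# Pure-Python sketch: summarize the structure returned by AgentService.get_agent_logs().
# Key names ("meta", "iterations", "tool_calls", ...) follow the result dict above;
# this function is illustrative and not part of the service.
def summarize_agent_logs(logs: dict) -> str:
    meta = logs["meta"]
    lines = [
        f"{meta['executor']} ran {meta['iterations']} iteration(s) "
        f"using {meta['total_tokens']} tokens ({meta['agent_mode']} mode)"
    ]
    for index, iteration in enumerate(logs["iterations"], start=1):
        tool_names = [call["tool_name"] for call in iteration["tool_calls"]] or ["<no tools>"]
        lines.append(f"  #{index}: tokens={iteration['tokens']}, tools={', '.join(tool_names)}")
    return "\n".join(lines)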


@@ -0,0 +1,558 @@
import uuid
import pandas as pd
from sqlalchemy import or_, select
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant
from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation
from services.feature_service import FeatureService
from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
from tasks.annotation.delete_annotation_index_task import delete_annotation_index_task
from tasks.annotation.disable_annotation_reply_task import disable_annotation_reply_task
from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task
from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task
class AppAnnotationService:
@classmethod
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info
current_user, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
answer = args.get("answer") or args.get("content")
if answer is None:
raise ValueError("Either 'answer' or 'content' must be provided")
if args.get("message_id"):
message_id = str(args["message_id"])
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app.id).first()
if not message:
raise NotFound("Message Not Exists.")
question = args.get("question") or message.query or ""
annotation: MessageAnnotation | None = message.annotation
if annotation:
annotation.content = answer
annotation.question = question
else:
annotation = MessageAnnotation(
app_id=app.id,
conversation_id=message.conversation_id,
message_id=message.id,
content=answer,
question=question,
account_id=current_user.id,
)
else:
question = args.get("question")
if not question:
raise ValueError("'question' is required when 'message_id' is not provided")
annotation = MessageAnnotation(app_id=app.id, content=answer, question=question, account_id=current_user.id)
db.session.add(annotation)
db.session.commit()
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
assert current_tenant_id is not None
if annotation_setting:
add_annotation_to_index_task.delay(
annotation.id,
annotation.question,
current_tenant_id,
app_id,
annotation_setting.collection_binding_id,
)
return annotation
@classmethod
def enable_app_annotation(cls, args: dict, app_id: str):
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(enable_app_annotation_key)
if cache_result is not None:
return {"job_id": cache_result, "job_status": "processing"}
# async job
job_id = str(uuid.uuid4())
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
# send batch add segments task
redis_client.setnx(enable_app_annotation_job_key, "waiting")
current_user, current_tenant_id = current_account_with_tenant()
enable_annotation_reply_task.delay(
str(job_id),
app_id,
current_user.id,
current_tenant_id,
args["score_threshold"],
args["embedding_provider_name"],
args["embedding_model_name"],
)
return {"job_id": job_id, "job_status": "waiting"}
@classmethod
def disable_app_annotation(cls, app_id: str):
_, current_tenant_id = current_account_with_tenant()
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(disable_app_annotation_key)
if cache_result is not None:
return {"job_id": cache_result, "job_status": "processing"}
# async job
job_id = str(uuid.uuid4())
disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}"
# send batch add segments task
redis_client.setnx(disable_app_annotation_job_key, "waiting")
disable_annotation_reply_task.delay(str(job_id), app_id, current_tenant_id)
return {"job_id": job_id, "job_status": "waiting"}
@classmethod
def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
# get app info
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
if keyword:
stmt = (
select(MessageAnnotation)
.where(MessageAnnotation.app_id == app_id)
.where(
or_(
MessageAnnotation.question.ilike(f"%{keyword}%"),
MessageAnnotation.content.ilike(f"%{keyword}%"),
)
)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
)
else:
stmt = (
select(MessageAnnotation)
.where(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
)
annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False)
return annotations.items, annotations.total
@classmethod
def export_annotation_list_by_app_id(cls, app_id: str):
# get app info
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
annotations = (
db.session.query(MessageAnnotation)
.where(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc())
.all()
)
return annotations
@classmethod
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info
current_user, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
annotation = MessageAnnotation(
app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id
)
db.session.add(annotation)
db.session.commit()
# if annotation reply is enabled, add annotation to index
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if annotation_setting:
add_annotation_to_index_task.delay(
annotation.id,
args["question"],
current_tenant_id,
app_id,
annotation_setting.collection_binding_id,
)
return annotation
@classmethod
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
# get app info
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
if not annotation:
raise NotFound("Annotation not found")
annotation.content = args["answer"]
annotation.question = args["question"]
db.session.commit()
# if annotation reply is enabled, add annotation to index
app_annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
if app_annotation_setting:
update_annotation_to_index_task.delay(
annotation.id,
annotation.question,
current_tenant_id,
app_id,
app_annotation_setting.collection_binding_id,
)
return annotation
@classmethod
def delete_app_annotation(cls, app_id: str, annotation_id: str):
# get app info
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
if not annotation:
raise NotFound("Annotation not found")
db.session.delete(annotation)
annotation_hit_histories = db.session.scalars(
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation_id)
).all()
if annotation_hit_histories:
for annotation_hit_history in annotation_hit_histories:
db.session.delete(annotation_hit_history)
db.session.commit()
# if annotation reply is enabled, delete annotation index
app_annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
if app_annotation_setting:
delete_annotation_index_task.delay(
annotation.id, app_id, current_tenant_id, app_annotation_setting.collection_binding_id
)
@classmethod
def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]):
# get app info
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
# Fetch annotations and their settings in a single query
annotations_to_delete = (
db.session.query(MessageAnnotation, AppAnnotationSetting)
.outerjoin(AppAnnotationSetting, MessageAnnotation.app_id == AppAnnotationSetting.app_id)
.where(MessageAnnotation.id.in_(annotation_ids))
.all()
)
if not annotations_to_delete:
return {"deleted_count": 0}
# Step 1: Extract IDs for bulk operations
annotation_ids_to_delete = [annotation.id for annotation, _ in annotations_to_delete]
# Step 2: Bulk delete hit histories in a single query
db.session.query(AppAnnotationHitHistory).where(
AppAnnotationHitHistory.annotation_id.in_(annotation_ids_to_delete)
).delete(synchronize_session=False)
# Step 3: Trigger async tasks for search index deletion
for annotation, annotation_setting in annotations_to_delete:
if annotation_setting:
delete_annotation_index_task.delay(
annotation.id, app_id, current_tenant_id, annotation_setting.collection_binding_id
)
# Step 4: Bulk delete annotations in a single query
deleted_count = (
db.session.query(MessageAnnotation)
.where(MessageAnnotation.id.in_(annotation_ids_to_delete))
.delete(synchronize_session=False)
)
db.session.commit()
return {"deleted_count": deleted_count}
@classmethod
def batch_import_app_annotations(cls, app_id, file: FileStorage):
# get app info
current_user, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
try:
# The first CSV row is treated as the header; data rows start from the second row
df = pd.read_csv(file.stream, dtype=str)
result = []
for _, row in df.iterrows():
content = {"question": row.iloc[0], "answer": row.iloc[1]}
result.append(content)
if len(result) == 0:
raise ValueError("The CSV file is empty.")
# check annotation limit
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
annotation_quota_limit = features.annotation_quota_limit
if annotation_quota_limit.limit < len(result) + annotation_quota_limit.size:
raise ValueError("The number of annotations exceeds the limit of your subscription.")
# async job
job_id = str(uuid.uuid4())
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
# send batch add segments task
redis_client.setnx(indexing_cache_key, "waiting")
batch_import_annotations_task.delay(str(job_id), result, app_id, current_tenant_id, current_user.id)
except Exception as e:
return {"error_msg": str(e)}
return {"job_id": job_id, "job_status": "waiting"}
@classmethod
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
_, current_tenant_id = current_account_with_tenant()
# get app info
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
if not annotation:
raise NotFound("Annotation not found")
stmt = (
select(AppAnnotationHitHistory)
.where(
AppAnnotationHitHistory.app_id == app_id,
AppAnnotationHitHistory.annotation_id == annotation_id,
)
.order_by(AppAnnotationHitHistory.created_at.desc())
)
annotation_hit_histories = db.paginate(
select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False
)
return annotation_hit_histories.items, annotation_hit_histories.total
@classmethod
def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None:
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
if not annotation:
return None
return annotation
@classmethod
def add_annotation_history(
cls,
annotation_id: str,
app_id: str,
annotation_question: str,
annotation_content: str,
query: str,
user_id: str,
message_id: str,
from_source: str,
score: float,
):
# add hit count to annotation
db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).update(
{MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, synchronize_session=False
)
annotation_hit_history = AppAnnotationHitHistory(
annotation_id=annotation_id,
app_id=app_id,
account_id=user_id,
question=query,
source=from_source,
score=score,
message_id=message_id,
annotation_question=annotation_question,
annotation_content=annotation_content,
)
db.session.add(annotation_hit_history)
db.session.commit()
@classmethod
def get_app_annotation_setting_by_app_id(cls, app_id: str):
_, current_tenant_id = current_account_with_tenant()
# get app info
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail
if collection_binding_detail:
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name,
},
}
else:
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {},
}
return {"enabled": False}
@classmethod
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
current_user, current_tenant_id = current_account_with_tenant()
# get app info
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
annotation_setting = (
db.session.query(AppAnnotationSetting)
.where(
AppAnnotationSetting.app_id == app_id,
AppAnnotationSetting.id == annotation_setting_id,
)
.first()
)
if not annotation_setting:
raise NotFound("App annotation not found")
annotation_setting.score_threshold = args["score_threshold"]
annotation_setting.updated_user_id = current_user.id
annotation_setting.updated_at = naive_utc_now()
db.session.add(annotation_setting)
db.session.commit()
collection_binding_detail = annotation_setting.collection_binding_detail
if collection_binding_detail:
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name,
},
}
else:
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {},
}
@classmethod
def clear_all_annotations(cls, app_id: str):
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
# if annotation reply is enabled, delete annotation index
app_annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
annotations_query = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id)
for annotation in annotations_query.yield_per(100):
annotation_hit_histories_query = db.session.query(AppAnnotationHitHistory).where(
AppAnnotationHitHistory.annotation_id == annotation.id
)
for annotation_hit_history in annotation_hit_histories_query.yield_per(100):
db.session.delete(annotation_hit_history)
# if annotation reply is enabled, delete annotation index
if app_annotation_setting:
delete_annotation_index_task.delay(
annotation.id, app_id, current_tenant_id, app_annotation_setting.collection_binding_id
)
db.session.delete(annotation)
db.session.commit()
return {"result": "success"}


@@ -0,0 +1,105 @@
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
from core.helper.encrypter import decrypt_token, encrypt_token
from extensions.ext_database import db
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
class APIBasedExtensionService:
@staticmethod
def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]:
extension_list = (
db.session.query(APIBasedExtension)
.filter_by(tenant_id=tenant_id)
.order_by(APIBasedExtension.created_at.desc())
.all()
)
for extension in extension_list:
extension.api_key = decrypt_token(extension.tenant_id, extension.api_key)
return extension_list
@classmethod
def save(cls, extension_data: APIBasedExtension) -> APIBasedExtension:
cls._validation(extension_data)
extension_data.api_key = encrypt_token(extension_data.tenant_id, extension_data.api_key)
db.session.add(extension_data)
db.session.commit()
return extension_data
@staticmethod
def delete(extension_data: APIBasedExtension):
db.session.delete(extension_data)
db.session.commit()
@staticmethod
def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
extension = (
db.session.query(APIBasedExtension)
.filter_by(tenant_id=tenant_id)
.filter_by(id=api_based_extension_id)
.first()
)
if not extension:
raise ValueError("API based extension is not found")
extension.api_key = decrypt_token(extension.tenant_id, extension.api_key)
return extension
@classmethod
def _validation(cls, extension_data: APIBasedExtension):
# name
if not extension_data.name:
raise ValueError("name must not be empty")
if not extension_data.id:
# case one: check new data, name must be unique
is_name_existed = (
db.session.query(APIBasedExtension)
.filter_by(tenant_id=extension_data.tenant_id)
.filter_by(name=extension_data.name)
.first()
)
if is_name_existed:
raise ValueError("name must be unique, it is already existed")
else:
# case two: check existing data, name must be unique
is_name_existed = (
db.session.query(APIBasedExtension)
.filter_by(tenant_id=extension_data.tenant_id)
.filter_by(name=extension_data.name)
.where(APIBasedExtension.id != extension_data.id)
.first()
)
if is_name_existed:
raise ValueError("name must be unique, it is already existed")
# api_endpoint
if not extension_data.api_endpoint:
raise ValueError("api_endpoint must not be empty")
# api_key
if not extension_data.api_key:
raise ValueError("api_key must not be empty")
if len(extension_data.api_key) < 5:
raise ValueError("api_key must be at least 5 characters")
# check endpoint
cls._ping_connection(extension_data)
@staticmethod
def _ping_connection(extension_data: APIBasedExtension):
try:
client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key)
resp = client.request(point=APIBasedExtensionPoint.PING, params={})
if resp.get("result") != "pong":
raise ValueError(resp)
except Exception as e:
raise ValueError(f"connection error: {e}")


@@ -0,0 +1,844 @@
import base64
import hashlib
import logging
import uuid
from collections.abc import Mapping
from enum import StrEnum
from urllib.parse import urlparse
from uuid import uuid4
import yaml
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from packaging import version
from packaging.version import parse as parse_version
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from core.helper import ssrf_proxy
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import PluginDependency
from core.workflow.enums import NodeType
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from core.workflow.nodes.llm.entities import LLMNodeData
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerScheduleNode
from events.app_event import app_model_config_was_updated, app_was_created
from extensions.ext_redis import redis_client
from factories import variable_factory
from libs.datetime_utils import naive_utc_now
from models import Account, App, AppMode
from models.model import AppModelConfig
from models.workflow import Workflow
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.workflow_draft_variable_service import WorkflowDraftVariableService
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
CURRENT_DSL_VERSION = "0.5.0"
class ImportMode(StrEnum):
YAML_CONTENT = "yaml-content"
YAML_URL = "yaml-url"
class ImportStatus(StrEnum):
COMPLETED = "completed"
COMPLETED_WITH_WARNINGS = "completed-with-warnings"
PENDING = "pending"
FAILED = "failed"
class Import(BaseModel):
id: str
status: ImportStatus
app_id: str | None = None
app_mode: str | None = None
current_dsl_version: str = CURRENT_DSL_VERSION
imported_dsl_version: str = ""
error: str = ""
class CheckDependenciesResult(BaseModel):
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
def _check_version_compatibility(imported_version: str) -> ImportStatus:
"""Determine import status based on version comparison"""
try:
current_ver = version.parse(CURRENT_DSL_VERSION)
imported_ver = version.parse(imported_version)
except version.InvalidVersion:
return ImportStatus.FAILED
# If imported version is newer than current, always return PENDING
if imported_ver > current_ver:
return ImportStatus.PENDING
# If imported version is older than current's major, return PENDING
if imported_ver.major < current_ver.major:
return ImportStatus.PENDING
# If imported version is older than current's minor, return COMPLETED_WITH_WARNINGS
if imported_ver.minor < current_ver.minor:
return ImportStatus.COMPLETED_WITH_WARNINGS
# Same major and minor (imported micro is equal or lower), return COMPLETED
return ImportStatus.COMPLETED
class PendingData(BaseModel):
import_mode: str
yaml_content: str
name: str | None = None
description: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
app_id: str | None = None
class CheckDependenciesPendingData(BaseModel):
dependencies: list[PluginDependency]
app_id: str | None = None
class AppDslService:
def __init__(self, session: Session):
self._session = session
def import_app(
self,
*,
account: Account,
import_mode: str,
yaml_content: str | None = None,
yaml_url: str | None = None,
name: str | None = None,
description: str | None = None,
icon_type: str | None = None,
icon: str | None = None,
icon_background: str | None = None,
app_id: str | None = None,
) -> Import:
"""Import an app from YAML content or URL."""
import_id = str(uuid.uuid4())
# Validate import mode
try:
mode = ImportMode(import_mode)
except ValueError:
raise ValueError(f"Invalid import_mode: {import_mode}")
# Get YAML content
content: str = ""
if mode == ImportMode.YAML_URL:
if not yaml_url:
return Import(
id=import_id,
status=ImportStatus.FAILED,
error="yaml_url is required when import_mode is yaml-url",
)
try:
parsed_url = urlparse(yaml_url)
if (
parsed_url.scheme == "https"
and parsed_url.netloc == "github.com"
and parsed_url.path.endswith((".yml", ".yaml"))
):
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
yaml_url = yaml_url.replace("/blob/", "/")
response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
response.raise_for_status()
content = response.content.decode()
if len(content) > DSL_MAX_SIZE:
return Import(
id=import_id,
status=ImportStatus.FAILED,
error="File size exceeds the limit of 10MB",
)
if not content:
return Import(
id=import_id,
status=ImportStatus.FAILED,
error="Empty content from url",
)
except Exception as e:
return Import(
id=import_id,
status=ImportStatus.FAILED,
error=f"Error fetching YAML from URL: {str(e)}",
)
elif mode == ImportMode.YAML_CONTENT:
if not yaml_content:
return Import(
id=import_id,
status=ImportStatus.FAILED,
error="yaml_content is required when import_mode is yaml-content",
)
content = yaml_content
# Process YAML content
try:
# Parse YAML to validate format
data = yaml.safe_load(content)
if not isinstance(data, dict):
return Import(
id=import_id,
status=ImportStatus.FAILED,
error="Invalid YAML format: content must be a mapping",
)
# Validate and fix DSL version
if not data.get("version"):
data["version"] = "0.1.0"
if not data.get("kind") or data.get("kind") != "app":
data["kind"] = "app"
imported_version = data.get("version", "0.1.0")
# ensure the version is a string (YAML may parse a bare numeric version as a float)
if not isinstance(imported_version, str):
raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}")
status = _check_version_compatibility(imported_version)
# Extract app data
app_data = data.get("app")
if not app_data:
return Import(
id=import_id,
status=ImportStatus.FAILED,
error="Missing app data in YAML content",
)
# If app_id is provided, check if it exists
app = None
if app_id:
stmt = select(App).where(App.id == app_id, App.tenant_id == account.current_tenant_id)
app = self._session.scalar(stmt)
if not app:
return Import(
id=import_id,
status=ImportStatus.FAILED,
error="App not found",
)
if app.mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]:
return Import(
id=import_id,
status=ImportStatus.FAILED,
error="Only workflow or advanced chat apps can be overwritten",
)
# If major version mismatch, store import info in Redis
if status == ImportStatus.PENDING:
pending_data = PendingData(
import_mode=import_mode,
yaml_content=content,
name=name,
description=description,
icon_type=icon_type,
icon=icon,
icon_background=icon_background,
app_id=app_id,
)
redis_client.setex(
f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}",
IMPORT_INFO_REDIS_EXPIRY,
pending_data.model_dump_json(),
)
return Import(
id=import_id,
status=status,
app_id=app_id,
imported_dsl_version=imported_version,
)
# Extract dependencies
dependencies = data.get("dependencies", [])
check_dependencies_pending_data = None
if dependencies:
check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]
elif parse_version(imported_version) <= parse_version("0.1.5"):
if "workflow" in data:
graph = data.get("workflow", {}).get("graph", {})
dependencies_list = self._extract_dependencies_from_workflow_graph(graph)
else:
dependencies_list = self._extract_dependencies_from_model_config(data.get("model_config", {}))
check_dependencies_pending_data = DependenciesAnalysisService.generate_latest_dependencies(
dependencies_list
)
# Create or update app
app = self._create_or_update_app(
app=app,
data=data,
account=account,
name=name,
description=description,
icon_type=icon_type,
icon=icon,
icon_background=icon_background,
dependencies=check_dependencies_pending_data,
)
draft_var_srv = WorkflowDraftVariableService(session=self._session)
draft_var_srv.delete_workflow_variables(app_id=app.id)
return Import(
id=import_id,
status=status,
app_id=app.id,
app_mode=app.mode,
imported_dsl_version=imported_version,
)
except yaml.YAMLError as e:
return Import(
id=import_id,
status=ImportStatus.FAILED,
error=f"Invalid YAML format: {str(e)}",
)
except Exception as e:
logger.exception("Failed to import app")
return Import(
id=import_id,
status=ImportStatus.FAILED,
error=str(e),
)
def confirm_import(self, *, import_id: str, account: Account) -> Import:
"""
Confirm an import that requires confirmation
"""
redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}"
pending_data = redis_client.get(redis_key)
if not pending_data:
return Import(
id=import_id,
status=ImportStatus.FAILED,
error="Import information expired or does not exist",
)
try:
if not isinstance(pending_data, str | bytes):
return Import(
id=import_id,
status=ImportStatus.FAILED,
error="Invalid import information",
)
pending_data = PendingData.model_validate_json(pending_data)
data = yaml.safe_load(pending_data.yaml_content)
app = None
if pending_data.app_id:
stmt = select(App).where(App.id == pending_data.app_id, App.tenant_id == account.current_tenant_id)
app = self._session.scalar(stmt)
# Create or update app
app = self._create_or_update_app(
app=app,
data=data,
account=account,
name=pending_data.name,
description=pending_data.description,
icon_type=pending_data.icon_type,
icon=pending_data.icon,
icon_background=pending_data.icon_background,
)
# Delete import info from Redis
redis_client.delete(redis_key)
return Import(
id=import_id,
status=ImportStatus.COMPLETED,
app_id=app.id,
app_mode=app.mode,
current_dsl_version=CURRENT_DSL_VERSION,
imported_dsl_version=data.get("version", "0.1.0"),
)
except Exception as e:
logger.exception("Error confirming import")
return Import(
id=import_id,
status=ImportStatus.FAILED,
error=str(e),
)
def check_dependencies(
self,
*,
app_model: App,
) -> CheckDependenciesResult:
"""Check dependencies"""
# Get dependencies from Redis
redis_key = f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app_model.id}"
dependencies = redis_client.get(redis_key)
if not dependencies:
return CheckDependenciesResult()
# Extract dependencies
dependencies = CheckDependenciesPendingData.model_validate_json(dependencies)
# Get leaked dependencies
leaked_dependencies = DependenciesAnalysisService.get_leaked_dependencies(
tenant_id=app_model.tenant_id, dependencies=dependencies.dependencies
)
return CheckDependenciesResult(
leaked_dependencies=leaked_dependencies,
)
def _create_or_update_app(
self,
*,
app: App | None,
data: dict,
account: Account,
name: str | None = None,
description: str | None = None,
icon_type: str | None = None,
icon: str | None = None,
icon_background: str | None = None,
dependencies: list[PluginDependency] | None = None,
) -> App:
"""Create a new app or update an existing one."""
app_data = data.get("app", {})
app_mode = app_data.get("mode")
if not app_mode:
raise ValueError("loss app mode")
app_mode = AppMode(app_mode)
# Set icon type
icon_type_value = icon_type or app_data.get("icon_type")
if icon_type_value in ["emoji", "link", "image"]:
icon_type = icon_type_value
else:
icon_type = "emoji"
icon = icon or str(app_data.get("icon", ""))
if app:
# Update existing app
app.name = name or app_data.get("name", app.name)
app.description = description or app_data.get("description", app.description)
app.icon_type = icon_type
app.icon = icon
app.icon_background = icon_background or app_data.get("icon_background", app.icon_background)
app.updated_by = account.id
app.updated_at = naive_utc_now()
else:
if account.current_tenant_id is None:
raise ValueError("Current tenant is not set")
# Create new app
app = App()
app.id = str(uuid4())
app.tenant_id = account.current_tenant_id
app.mode = app_mode.value
app.name = name or app_data.get("name", "")
app.description = description or app_data.get("description", "")
app.icon_type = icon_type
app.icon = icon
app.icon_background = icon_background or app_data.get("icon_background", "#FFFFFF")
app.enable_site = True
app.enable_api = True
app.use_icon_as_answer_icon = app_data.get("use_icon_as_answer_icon", False)
app.created_by = account.id
app.updated_by = account.id
self._session.add(app)
self._session.commit()
app_was_created.send(app, account=account)
# save dependencies
if dependencies:
redis_client.setex(
f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app.id}",
IMPORT_INFO_REDIS_EXPIRY,
CheckDependenciesPendingData(app_id=app.id, dependencies=dependencies).model_dump_json(),
)
# Initialize app based on mode
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow_data = data.get("workflow")
if not workflow_data or not isinstance(workflow_data, dict):
raise ValueError("Missing workflow data for workflow/advanced chat app")
environment_variables_list = workflow_data.get("environment_variables", [])
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = workflow_data.get("conversation_variables", [])
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
workflow_service = WorkflowService()
current_draft_workflow = workflow_service.get_draft_workflow(app_model=app)
if current_draft_workflow:
unique_hash = current_draft_workflow.unique_hash
else:
unique_hash = None
graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
decrypted_id
for dataset_id in dataset_ids
if (decrypted_id := self.decrypt_dataset_id(encrypted_data=dataset_id, tenant_id=app.tenant_id))
]
workflow_service.sync_draft_workflow(
app_model=app,
graph=workflow_data.get("graph", {}),
features=workflow_data.get("features", {}),
unique_hash=unique_hash,
account=account,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
)
elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}:
# Initialize model config
model_config = data.get("model_config")
if not model_config or not isinstance(model_config, dict):
raise ValueError("Missing model_config for chat/agent-chat/completion app")
# Initialize or update model config
if not app.app_model_config:
app_model_config = AppModelConfig().from_model_config_dict(model_config)
app_model_config.id = str(uuid4())
app_model_config.app_id = app.id
app_model_config.created_by = account.id
app_model_config.updated_by = account.id
app.app_model_config_id = app_model_config.id
self._session.add(app_model_config)
app_model_config_was_updated.send(app, app_model_config=app_model_config)
else:
raise ValueError("Invalid app mode")
return app
@classmethod
def export_dsl(cls, app_model: App, include_secret: bool = False, workflow_id: str | None = None) -> str:
"""
Export app
:param app_model: App instance
:param include_secret: whether to include secret variables
:param workflow_id: optional workflow id
:return:
"""
app_mode = AppMode.value_of(app_model.mode)
export_data = {
"version": CURRENT_DSL_VERSION,
"kind": "app",
"app": {
"name": app_model.name,
"mode": app_model.mode,
"icon": app_model.icon if app_model.icon_type == "image" else "🤖",
"icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background,
"description": app_model.description,
"use_icon_as_answer_icon": app_model.use_icon_as_answer_icon,
},
}
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
cls._append_workflow_export_data(
export_data=export_data, app_model=app_model, include_secret=include_secret, workflow_id=workflow_id
)
else:
cls._append_model_config_export_data(export_data, app_model)
return yaml.dump(export_data, allow_unicode=True)
@classmethod
def _append_workflow_export_data(
cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: str | None = None
):
"""
Append workflow export data
:param export_data: export data
:param app_model: App instance
"""
workflow_service = WorkflowService()
workflow = workflow_service.get_draft_workflow(app_model, workflow_id)
if not workflow:
raise ValueError("Missing draft workflow configuration, please check.")
workflow_dict = workflow.to_dict(include_secret=include_secret)
# TODO: refactor: we need a better way to filter workspace related data from nodes
for node in workflow_dict.get("graph", {}).get("nodes", []):
node_data = node.get("data", {})
if not node_data:
continue
data_type = node_data.get("type", "")
if data_type == NodeType.KNOWLEDGE_RETRIEVAL:
dataset_ids = node_data.get("dataset_ids", [])
node_data["dataset_ids"] = [
cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id)
for dataset_id in dataset_ids
]
# filter credential id from tool node
if not include_secret and data_type == NodeType.TOOL:
node_data.pop("credential_id", None)
# filter credential id from agent node
if not include_secret and data_type == NodeType.AGENT:
for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
tool.pop("credential_id", None)
if data_type == NodeType.TRIGGER_SCHEDULE.value:
# override the config with the default config
node_data["config"] = TriggerScheduleNode.get_default_config()["config"]
if data_type == NodeType.TRIGGER_WEBHOOK.value:
# clear the webhook_url
node_data["webhook_url"] = ""
node_data["webhook_debug_url"] = ""
if data_type == NodeType.TRIGGER_PLUGIN.value:
# clear the subscription_id
node_data["subscription_id"] = ""
export_data["workflow"] = workflow_dict
dependencies = cls._extract_dependencies_from_workflow(workflow)
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())
for d in DependenciesAnalysisService.generate_dependencies(
tenant_id=app_model.tenant_id, dependencies=dependencies
)
]
@classmethod
def _append_model_config_export_data(cls, export_data: dict, app_model: App):
"""
Append model config export data
:param export_data: export data
:param app_model: App instance
"""
app_model_config = app_model.app_model_config
if not app_model_config:
raise ValueError("Missing app configuration, please check.")
model_config = app_model_config.to_dict()
# TODO: refactor: we need a better way to filter workspace related data from model config
# filter credential id from model config
for tool in model_config.get("agent_mode", {}).get("tools", []):
tool.pop("credential_id", None)
export_data["model_config"] = model_config
dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict())
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())
for d in DependenciesAnalysisService.generate_dependencies(
tenant_id=app_model.tenant_id, dependencies=dependencies
)
]
@classmethod
def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]:
"""
Extract dependencies from workflow
:param workflow: Workflow instance
:return: dependencies list format like ["langgenius/google"]
"""
graph = workflow.graph_dict
dependencies = cls._extract_dependencies_from_workflow_graph(graph)
return dependencies
@classmethod
def _extract_dependencies_from_workflow_graph(cls, graph: Mapping) -> list[str]:
"""
Extract dependencies from workflow graph
:param graph: Workflow graph
:return: dependencies list format like ["langgenius/google"]
"""
dependencies = []
for node in graph.get("nodes", []):
try:
typ = node.get("data", {}).get("type")
match typ:
case NodeType.TOOL:
tool_entity = ToolNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id),
)
case NodeType.LLM:
llm_entity = LLMNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider),
)
case NodeType.QUESTION_CLASSIFIER:
question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
question_classifier_entity.model.provider
),
)
case NodeType.PARAMETER_EXTRACTOR:
parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
parameter_extractor_entity.model.provider
),
)
case NodeType.KNOWLEDGE_RETRIEVAL:
knowledge_retrieval_entity = KnowledgeRetrievalNodeData.model_validate(node["data"])
if knowledge_retrieval_entity.retrieval_mode == "multiple":
if knowledge_retrieval_entity.multiple_retrieval_config:
if (
knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode
== "reranking_model"
):
if knowledge_retrieval_entity.multiple_retrieval_config.reranking_model:
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
knowledge_retrieval_entity.multiple_retrieval_config.reranking_model.provider
),
)
elif (
knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode
== "weighted_score"
):
if knowledge_retrieval_entity.multiple_retrieval_config.weights:
vector_setting = (
knowledge_retrieval_entity.multiple_retrieval_config.weights.vector_setting
)
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
vector_setting.embedding_provider_name
),
)
elif knowledge_retrieval_entity.retrieval_mode == "single":
model_config = knowledge_retrieval_entity.single_retrieval_config
if model_config:
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
model_config.model.provider
),
)
case _:
# TODO: Handle default case or unknown node types
pass
except Exception as e:
logger.exception("Error extracting node dependency", exc_info=e)
return dependencies
@classmethod
def _extract_dependencies_from_model_config(cls, model_config: Mapping) -> list[str]:
"""
Extract dependencies from model config
:param model_config: model config dict
:return: dependencies list format like ["langgenius/google"]
"""
dependencies = []
try:
# completion model
model_dict = model_config.get("model", {})
if model_dict:
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(model_dict.get("provider", ""))
)
# reranking model
dataset_configs = model_config.get("dataset_configs", {})
if dataset_configs:
for dataset_config in dataset_configs.get("datasets", {}).get("datasets", []):
if dataset_config.get("reranking_model"):
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
dataset_config.get("reranking_model", {})
.get("reranking_provider_name", {})
.get("provider")
)
)
# tools
agent_configs = model_config.get("agent_mode", {})
if agent_configs:
for agent_config in agent_configs.get("tools", []):
dependencies.append(
DependenciesAnalysisService.analyze_tool_dependency(agent_config.get("provider_id"))
)
except Exception as e:
logger.exception("Error extracting model config dependency", exc_info=e)
return dependencies
@classmethod
def get_leaked_dependencies(cls, tenant_id: str, dsl_dependencies: list[dict]) -> list[PluginDependency]:
"""
Returns the leaked dependencies in current workspace
"""
dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies]
if not dependencies:
return []
return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)
@staticmethod
def _generate_aes_key(tenant_id: str) -> bytes:
"""Generate AES key based on tenant_id"""
return hashlib.sha256(tenant_id.encode()).digest()
@classmethod
def encrypt_dataset_id(cls, dataset_id: str, tenant_id: str) -> str:
"""Encrypt dataset_id using AES-CBC mode or return plain text based on configuration"""
if not dify_config.DSL_EXPORT_ENCRYPT_DATASET_ID:
return dataset_id
key = cls._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
ct_bytes = cipher.encrypt(pad(dataset_id.encode(), AES.block_size))
return base64.b64encode(ct_bytes).decode()
@classmethod
def decrypt_dataset_id(cls, encrypted_data: str, tenant_id: str) -> str | None:
"""AES decryption with fallback to plain text UUID"""
# First, check if it's already a plain UUID (not encrypted)
if cls._is_valid_uuid(encrypted_data):
return encrypted_data
# If it's not a UUID, try to decrypt it
try:
key = cls._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size)
decrypted_text = pt.decode()
# Validate that the decrypted result is a valid UUID
if cls._is_valid_uuid(decrypted_text):
return decrypted_text
else:
# If decrypted result is not a valid UUID, it's probably not our encrypted data
return None
except Exception:
# If decryption fails completely, return None
return None
@staticmethod
def _is_valid_uuid(value: str) -> bool:
"""Check if string is a valid UUID format"""
try:
uuid.UUID(value)
return True
except (ValueError, TypeError):
return False
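
The dataset-id protection used by export_dsl()/import_app() is self-contained, so a standalone round trip can illustrate it: the same key derivation (SHA-256 of the tenant id) and IV (first 16 key bytes) as encrypt_dataset_id()/decrypt_dataset_id() above, minus the dify_config toggle. The ids are random placeholders.

# Standalone round trip of the dataset-id scheme used above (AES-CBC, key = sha256(tenant_id),
# iv = key[:16]); the dify_config toggle is omitted and the ids are random placeholders.
import base64
import hashlib
import uuid

from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad

tenant_id = str(uuid.uuid4())
dataset_id = str(uuid.uuid4())

key = hashlib.sha256(tenant_id.encode()).digest()
iv = key[:16]

encrypted = base64.b64encode(
    AES.new(key, AES.MODE_CBC, iv).encrypt(pad(dataset_id.encode(), AES.block_size))
).decode()
decrypted = unpad(
    AES.new(key, AES.MODE_CBC, iv).decrypt(base64.b64decode(encrypted)), AES.block_size
).decode()
assert decrypted == dataset_id  # round trip matches, as decrypt_dataset_id() expects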


@@ -0,0 +1,240 @@
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Union
from configs import dify_config
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
from core.app.apps.chat.app_generator import ChatAppGenerator
from core.app.apps.completion.app_generator import CompletionAppGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.features.rate_limiting import RateLimit
from enums.quota_type import QuotaType, unlimited
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow
from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.workflow_service import WorkflowService
class AppGenerateService:
@classmethod
def generate(
cls,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
root_node_id: str | None = None,
):
"""
App Content Generate
:param app_model: app model
:param user: user
:param args: args
:param invoke_from: invoke from
:param streaming: streaming
:return:
"""
quota_charge = unlimited()
if dify_config.BILLING_ENABLED:
try:
quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id)
except QuotaExceededError:
raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}")
# app level rate limiter
max_active_request = cls._get_max_active_requests(app_model)
rate_limit = RateLimit(app_model.id, max_active_request)
request_id = RateLimit.gen_request_key()
try:
request_id = rate_limit.enter(request_id)
if app_model.mode == AppMode.COMPLETION:
return rate_limit.generate(
CompletionAppGenerator.convert_to_event_stream(
CompletionAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
),
),
request_id=request_id,
)
elif app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
return rate_limit.generate(
AgentChatAppGenerator.convert_to_event_stream(
AgentChatAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
),
),
request_id,
)
elif app_model.mode == AppMode.CHAT:
return rate_limit.generate(
ChatAppGenerator.convert_to_event_stream(
ChatAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
),
),
request_id=request_id,
)
elif app_model.mode == AppMode.ADVANCED_CHAT:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
return rate_limit.generate(
AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().generate(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=streaming,
),
),
request_id=request_id,
)
elif app_model.mode == AppMode.WORKFLOW:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
return rate_limit.generate(
WorkflowAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().generate(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=streaming,
root_node_id=root_node_id,
call_depth=0,
),
),
request_id,
)
else:
raise ValueError(f"Invalid app mode {app_model.mode}")
except Exception:
quota_charge.refund()
rate_limit.exit(request_id)
raise
finally:
if not streaming:
rate_limit.exit(request_id)
@staticmethod
def _get_max_active_requests(app: App) -> int:
"""
Get the maximum number of active requests allowed for an app.
Returns the smaller value between app's custom limit and global config limit.
A value of 0 means infinite (no limit).
Args:
app: The App model instance
Returns:
The maximum number of active requests allowed
"""
app_limit = app.max_active_requests or 0
config_limit = dify_config.APP_MAX_ACTIVE_REQUESTS
# Filter out infinite (0) values and return the minimum, or 0 if both are infinite
limits = [limit for limit in [app_limit, config_limit] if limit > 0]
return min(limits) if limits else 0
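# Illustrative resolutions of the two limits above (assumed values):
#   app limit 0,  config limit 100 -> 100
#   app limit 10, config limit 100 -> 10
#   app limit 0,  config limit 0   -> 0 (no limit)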
@classmethod
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_iteration_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)
elif app_model.mode == AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_iteration_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)
else:
raise ValueError(f"Invalid app mode {app_model.mode}")
@classmethod
def generate_single_loop(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_loop_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)
elif app_model.mode == AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_loop_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)
else:
raise ValueError(f"Invalid app mode {app_model.mode}")
@classmethod
def generate_more_like_this(
cls,
app_model: App,
user: Union[Account, EndUser],
message_id: str,
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping, Generator]:
"""
Generate more like this
:param app_model: app model
:param user: user
:param message_id: message id
:param invoke_from: invoke from
:param streaming: streaming
:return:
"""
return CompletionAppGenerator().generate_more_like_this(
app_model=app_model, message_id=message_id, user=user, invoke_from=invoke_from, stream=streaming
)
@classmethod
def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: str | None = None) -> Workflow:
"""
Get workflow
:param app_model: app model
:param invoke_from: invoke from
:param workflow_id: optional workflow id to specify a specific version
:return:
"""
workflow_service = WorkflowService()
# If workflow_id is specified, get the specific workflow version
if workflow_id:
try:
_ = uuid.UUID(workflow_id)
except ValueError:
raise WorkflowIdFormatError(f"Invalid workflow_id format: '{workflow_id}'. ")
workflow = workflow_service.get_published_workflow_by_id(app_model=app_model, workflow_id=workflow_id)
if not workflow:
raise WorkflowNotFoundError(f"Workflow not found with id: {workflow_id}")
return workflow
if invoke_from == InvokeFrom.DEBUGGER:
# fetch draft workflow by app_model
workflow = workflow_service.get_draft_workflow(app_model=app_model)
if not workflow:
raise ValueError("Workflow not initialized")
else:
# fetch published workflow by app_model
workflow = workflow_service.get_published_workflow(app_model=app_model)
if not workflow:
raise ValueError("Workflow not published")
return workflow

View File

@@ -0,0 +1,17 @@
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from models.model import AppMode
class AppModelConfigService:
@classmethod
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode):
if app_mode == AppMode.CHAT:
return ChatAppConfigManager.config_validate(tenant_id, config)
elif app_mode == AppMode.AGENT_CHAT:
return AgentChatAppConfigManager.config_validate(tenant_id, config)
elif app_mode == AppMode.COMPLETION:
return CompletionAppConfigManager.config_validate(tenant_id, config)
else:
raise ValueError(f"Invalid app mode: {app_mode}")

View File

@@ -0,0 +1,437 @@
import json
import logging
from typing import TypedDict, cast
import sqlalchemy as sa
from flask_sqlalchemy.pagination import Pagination
from configs import dify_config
from constants.model_template import default_app_templates
from core.agent.entities import AgentToolEntity
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account
from models.model import App, AppMode, AppModelConfig, Site
from models.tools import ApiToolProvider
from services.billing_service import BillingService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.tag_service import TagService
from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task
logger = logging.getLogger(__name__)
class AppService:
def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Pagination | None:
"""
Get app list with pagination
:param user_id: user id
:param tenant_id: tenant id
:param args: request args
:return:
"""
filters = [App.tenant_id == tenant_id, App.is_universal == False]
if args["mode"] == "workflow":
filters.append(App.mode == AppMode.WORKFLOW)
elif args["mode"] == "completion":
filters.append(App.mode == AppMode.COMPLETION)
elif args["mode"] == "chat":
filters.append(App.mode == AppMode.CHAT)
elif args["mode"] == "advanced-chat":
filters.append(App.mode == AppMode.ADVANCED_CHAT)
elif args["mode"] == "agent-chat":
filters.append(App.mode == AppMode.AGENT_CHAT)
if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
if args.get("name"):
name = args["name"][:30]
filters.append(App.name.ilike(f"%{name}%"))
# Check if tag_ids is not empty to avoid WHERE false condition
if args.get("tag_ids") and len(args["tag_ids"]) > 0:
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"])
if target_ids and len(target_ids) > 0:
filters.append(App.id.in_(target_ids))
else:
return None
app_models = db.paginate(
sa.select(App).where(*filters).order_by(App.created_at.desc()),
page=args["page"],
per_page=args["limit"],
error_out=False,
)
return app_models
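# Shape of `args` consumed by the filters above (keys taken from this method;
# example values are placeholders):
#   {"page": 1, "limit": 20, "mode": "chat", "name": "demo",
#    "is_created_by_me": False, "tag_ids": []}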
def create_app(self, tenant_id: str, args: dict, account: Account) -> App:
"""
Create app
:param tenant_id: tenant id
:param args: request args
:param account: Account instance
"""
app_mode = AppMode.value_of(args["mode"])
app_template = default_app_templates[app_mode]
# get model config
default_model_config = app_template.get("model_config")
default_model_config = default_model_config.copy() if default_model_config else None
if default_model_config and "model" in default_model_config:
# get model provider
model_manager = ModelManager()
# get default model instance
try:
model_instance = model_manager.get_default_model_instance(
tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM
)
except (ProviderTokenNotInitError, LLMBadRequestError):
model_instance = None
except Exception:
logger.exception("Get default model instance failed, tenant_id: %s", tenant_id)
model_instance = None
if model_instance:
if (
model_instance.model == default_model_config["model"]["name"]
and model_instance.provider == default_model_config["model"]["provider"]
):
default_model_dict = default_model_config["model"]
else:
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if model_schema is None:
raise ValueError(f"model schema not found for model {model_instance.model}")
default_model_dict = {
"provider": model_instance.provider,
"name": model_instance.model,
"mode": model_schema.model_properties.get(ModelPropertyKey.MODE),
"completion_params": {},
}
else:
provider, model = model_manager.get_default_provider_model_name(
tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM
)
default_model_config["model"]["provider"] = provider
default_model_config["model"]["name"] = model
default_model_dict = default_model_config["model"]
default_model_config["model"] = json.dumps(default_model_dict)
app = App(**app_template["app"])
app.name = args["name"]
app.description = args.get("description", "")
app.mode = args["mode"]
app.icon_type = args.get("icon_type", "emoji")
app.icon = args["icon"]
app.icon_background = args["icon_background"]
app.tenant_id = tenant_id
app.api_rph = args.get("api_rph", 0)
app.api_rpm = args.get("api_rpm", 0)
app.created_by = account.id
app.updated_by = account.id
db.session.add(app)
db.session.flush()
if default_model_config:
app_model_config = AppModelConfig(**default_model_config)
app_model_config.app_id = app.id
app_model_config.created_by = account.id
app_model_config.updated_by = account.id
db.session.add(app_model_config)
db.session.flush()
app.app_model_config_id = app_model_config.id
db.session.commit()
app_was_created.send(app, account=account)
if FeatureService.get_system_features().webapp_auth.enabled:
# update web app setting as private
EnterpriseService.WebAppAuth.update_app_access_mode(app.id, "private")
if dify_config.BILLING_ENABLED:
BillingService.clean_billing_info_cache(app.tenant_id)
return app
def get_app(self, app: App) -> App:
"""
Get App
"""
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get original app model config
if app.mode == AppMode.AGENT_CHAT or app.is_agent:
model_config = app.app_model_config
if not model_config:
return app
agent_mode = model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
for tool in agent_mode.get("tools") or []:
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue
agent_tool_entity = AgentToolEntity(**tool)
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
app_id=app.id,
agent_tool=agent_tool_entity,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
identity_id=f"AGENT.{app.id}",
)
# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
masked_parameter = {}
# override tool parameters
tool["tool_parameters"] = masked_parameter
except Exception:
pass
# override agent mode
if model_config:
model_config.agent_mode = json.dumps(agent_mode)
class ModifiedApp(App):
"""
Modified App class
"""
def __init__(self, app):
self.__dict__.update(app.__dict__)
@property
def app_model_config(self):
return model_config
app = ModifiedApp(app)
return app
class ArgsDict(TypedDict):
name: str
description: str
icon_type: str
icon: str
icon_background: str
use_icon_as_answer_icon: bool
max_active_requests: int
def update_app(self, app: App, args: ArgsDict) -> App:
"""
Update app
:param app: App instance
:param args: request args
:return: App instance
"""
assert current_user is not None
app.name = args["name"]
app.description = args["description"]
app.icon_type = args["icon_type"]
app.icon = args["icon"]
app.icon_background = args["icon_background"]
app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False)
app.max_active_requests = args.get("max_active_requests")
app.updated_by = current_user.id
app.updated_at = naive_utc_now()
db.session.commit()
return app
def update_app_name(self, app: App, name: str) -> App:
"""
Update app name
:param app: App instance
:param name: new name
:return: App instance
"""
assert current_user is not None
app.name = name
app.updated_by = current_user.id
app.updated_at = naive_utc_now()
db.session.commit()
return app
def update_app_icon(self, app: App, icon: str, icon_background: str) -> App:
"""
Update app icon
:param app: App instance
:param icon: new icon
:param icon_background: new icon_background
:return: App instance
"""
assert current_user is not None
app.icon = icon
app.icon_background = icon_background
app.updated_by = current_user.id
app.updated_at = naive_utc_now()
db.session.commit()
return app
def update_app_site_status(self, app: App, enable_site: bool) -> App:
"""
Update app site status
:param app: App instance
:param enable_site: enable site status
:return: App instance
"""
if enable_site == app.enable_site:
return app
assert current_user is not None
app.enable_site = enable_site
app.updated_by = current_user.id
app.updated_at = naive_utc_now()
db.session.commit()
return app
def update_app_api_status(self, app: App, enable_api: bool) -> App:
"""
Update app api status
:param app: App instance
:param enable_api: enable api status
:return: App instance
"""
if enable_api == app.enable_api:
return app
assert current_user is not None
app.enable_api = enable_api
app.updated_by = current_user.id
app.updated_at = naive_utc_now()
db.session.commit()
return app
def delete_app(self, app: App):
"""
Delete app
:param app: App instance
"""
db.session.delete(app)
db.session.commit()
# clean up web app settings
if FeatureService.get_system_features().webapp_auth.enabled:
EnterpriseService.WebAppAuth.cleanup_webapp(app.id)
if dify_config.BILLING_ENABLED:
BillingService.clean_billing_info_cache(app.tenant_id)
# Trigger asynchronous deletion of app and related data
remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id)
def get_app_meta(self, app_model: App):
"""
Get app meta info
:param app_model: app model
:return:
"""
app_mode = AppMode.value_of(app_model.mode)
meta: dict = {"tool_icons": {}}
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
return meta
graph = workflow.graph_dict
nodes = graph.get("nodes", [])
tools = []
for node in nodes:
if node.get("data", {}).get("type") == "tool":
node_data = node.get("data", {})
tools.append(
{
"provider_type": node_data.get("provider_type"),
"provider_id": node_data.get("provider_id"),
"tool_name": node_data.get("tool_name"),
"tool_parameters": {},
}
)
else:
app_model_config: AppModelConfig | None = app_model.app_model_config
if not app_model_config:
return meta
agent_config = app_model_config.agent_mode_dict
# get all tools
tools = agent_config.get("tools", [])
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
for tool in tools:
keys = list(tool.keys())
if len(keys) >= 4:
# current tool standard
provider_type = tool.get("provider_type", "")
provider_id = tool.get("provider_id", "")
tool_name = tool.get("tool_name", "")
if provider_type == "builtin":
meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon"
elif provider_type == "api":
try:
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider).where(ApiToolProvider.id == provider_id).first()
)
if provider is None:
raise ValueError(f"provider not found for tool {tool_name}")
meta["tool_icons"][tool_name] = json.loads(provider.icon)
except Exception:
meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"}
return meta
@staticmethod
def get_app_code_by_id(app_id: str) -> str:
"""
Get app code by app id
:param app_id: app id
:return: app code
"""
site = db.session.query(Site).where(Site.app_id == app_id).first()
if not site:
raise ValueError(f"App with id {app_id} not found")
return str(site.code)
@staticmethod
def get_app_id_by_code(app_code: str) -> str:
"""
Get app id by app code
:param app_code: app code
:return: app id
"""
site = db.session.query(Site).where(Site.code == app_code).first()
if not site:
raise ValueError(f"App with code {app_code} not found")
return str(site.app_id)

View File

@@ -0,0 +1,45 @@
"""Service for managing application task operations.
This service provides centralized logic for task control operations
like stopping tasks, handling both the legacy Redis flag mechanism and the
new GraphEngine command channel mechanism.
"""
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.graph_engine.manager import GraphEngineManager
from models.model import AppMode
class AppTaskService:
"""Service for managing application task operations."""
@staticmethod
def stop_task(
task_id: str,
invoke_from: InvokeFrom,
user_id: str,
app_mode: AppMode,
) -> None:
"""Stop a running task.
This method handles stopping tasks using both mechanisms:
1. Legacy Redis flag mechanism (for backward compatibility)
2. New GraphEngine command channel (for workflow-based apps)
Args:
task_id: The task ID to stop
invoke_from: The source of the invoke (e.g., DEBUGGER, WEB_APP, SERVICE_API)
user_id: The user ID requesting the stop
app_mode: The application mode (CHAT, AGENT_CHAT, ADVANCED_CHAT, WORKFLOW, etc.)
Returns:
None
"""
# Legacy mechanism: Set stop flag in Redis
AppQueueManager.set_stop_flag(task_id, invoke_from, user_id)
# New mechanism: Send stop command via GraphEngine for workflow-based apps
# This ensures proper workflow status recording in the persistence layer
if app_mode in (AppMode.ADVANCED_CHAT, AppMode.WORKFLOW):
GraphEngineManager.send_stop_command(task_id)
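# Minimal usage sketch (caller-side names are assumptions):
#   AppTaskService.stop_task(
#       task_id=task_id,
#       invoke_from=InvokeFrom.WEB_APP,
#       user_id=current_user.id,
#       app_mode=AppMode.ADVANCED_CHAT,
#   )
# The Redis stop flag is always set; the GraphEngine stop command is sent only
# for ADVANCED_CHAT and WORKFLOW apps.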

View File

@@ -0,0 +1,321 @@
"""
Universal async workflow execution service.
This service provides a centralized entry point for triggering workflows asynchronously
with support for different subscription tiers, rate limiting, and execution tracking.
"""
import json
from datetime import UTC, datetime
from typing import Any, Union
from celery.result import AsyncResult
from sqlalchemy import select
from sqlalchemy.orm import Session
from enums.quota_type import QuotaType
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatorUserRole, WorkflowTriggerStatus
from models.model import App, EndUser
from models.trigger import WorkflowTriggerLog
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowNotFoundError
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
from services.workflow_service import WorkflowService
from tasks.async_workflow_tasks import (
execute_workflow_professional,
execute_workflow_sandbox,
execute_workflow_team,
)
class AsyncWorkflowService:
"""
Universal entry point for async workflow execution - ALL METHODS ARE NON-BLOCKING
This service handles:
- Trigger data validation and processing
- Queue routing based on subscription tier
- Daily rate limiting with timezone support
- Execution tracking and logging
- Retry mechanisms for failed executions
Important: All trigger methods return immediately after queuing tasks.
Actual workflow execution happens asynchronously in background Celery workers.
Use trigger log IDs to monitor execution status and results.
"""
@classmethod
def trigger_workflow_async(
cls, session: Session, user: Union[Account, EndUser], trigger_data: TriggerData
) -> AsyncTriggerResponse:
"""
Universal entry point for async workflow execution - THIS METHOD WILL NOT BLOCK
Creates a trigger log and dispatches to appropriate queue based on subscription tier.
The workflow execution happens asynchronously in the background via Celery workers.
This method returns immediately after queuing the task, not after execution completion.
Args:
session: Database session to use for operations
user: User (Account or EndUser) who initiated the workflow trigger
trigger_data: Validated Pydantic model containing trigger information
Returns:
AsyncTriggerResponse with workflow_trigger_log_id, task_id, status="queued", and queue
Note: The actual workflow execution status must be checked separately via workflow_trigger_log_id
Raises:
WorkflowNotFoundError: If app or workflow not found
InvokeRateLimitError: If the tenant's workflow execution quota is exhausted
Behavior:
- Non-blocking: Returns immediately after queuing
- Asynchronous: Actual execution happens in background Celery workers
- Status tracking: Use workflow_trigger_log_id to monitor progress
- Queue-based: Routes to different queues based on subscription tier
"""
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
dispatcher_manager = QueueDispatcherManager()
workflow_service = WorkflowService()
# 1. Validate app exists
app_model = session.scalar(select(App).where(App.id == trigger_data.app_id))
if not app_model:
raise WorkflowNotFoundError(f"App not found: {trigger_data.app_id}")
# 2. Get workflow
workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id)
# 3. Get dispatcher based on tenant subscription
dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id)
# 4. Rate limiting check will be done without timezone first
# 5. Determine user role and ID
if isinstance(user, Account):
created_by_role = CreatorUserRole.ACCOUNT
created_by = user.id
else: # EndUser
created_by_role = CreatorUserRole.END_USER
created_by = user.id
# 6. Create trigger log entry first (for tracking)
trigger_log = WorkflowTriggerLog(
tenant_id=trigger_data.tenant_id,
app_id=trigger_data.app_id,
workflow_id=workflow.id,
root_node_id=trigger_data.root_node_id,
trigger_metadata=(
trigger_data.trigger_metadata.model_dump_json() if trigger_data.trigger_metadata else "{}"
),
trigger_type=trigger_data.trigger_type,
workflow_run_id=None,
outputs=None,
trigger_data=trigger_data.model_dump_json(),
inputs=json.dumps(dict(trigger_data.inputs)),
status=WorkflowTriggerStatus.PENDING,
queue_name=dispatcher.get_queue_name(),
retry_count=0,
created_by_role=created_by_role,
created_by=created_by,
celery_task_id=None,
error=None,
elapsed_time=None,
total_tokens=None,
)
trigger_log = trigger_log_repo.create(trigger_log)
session.commit()
# 7. Check and consume quota
try:
QuotaType.WORKFLOW.consume(trigger_data.tenant_id)
except QuotaExceededError as e:
# Update trigger log status
trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED
trigger_log.error = f"Quota limit reached: {e}"
trigger_log_repo.update(trigger_log)
session.commit()
raise InvokeRateLimitError(
f"Workflow execution quota limit reached for tenant {trigger_data.tenant_id}"
) from e
# 8. Create task data
queue_name = dispatcher.get_queue_name()
task_data = WorkflowTaskData(workflow_trigger_log_id=trigger_log.id)
# 9. Dispatch to appropriate queue
task_data_dict = task_data.model_dump(mode="json")
task: AsyncResult[Any] | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict) # type: ignore
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict) # type: ignore
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict) # type: ignore
# 10. Update trigger log with task info
trigger_log.status = WorkflowTriggerStatus.QUEUED
trigger_log.celery_task_id = task.id
trigger_log.triggered_at = datetime.now(UTC)
trigger_log_repo.update(trigger_log)
session.commit()
return AsyncTriggerResponse(
workflow_trigger_log_id=trigger_log.id,
task_id=task.id, # type: ignore
status="queued",
queue=queue_name,
)
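# Usage sketch (illustrative; session handling and the trigger_data fields are
# assumptions based on the imports above):
#   with Session(db.engine) as session:
#       result = AsyncWorkflowService.trigger_workflow_async(session, account, trigger_data)
#   result.workflow_trigger_log_id  # poll the trigger log to follow execution
# The call returns once the Celery task is queued; it never waits for the
# workflow run to finish.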
@classmethod
def reinvoke_trigger(
cls, session: Session, user: Union[Account, EndUser], workflow_trigger_log_id: str
) -> AsyncTriggerResponse:
"""
Re-invoke a previously failed or rate-limited trigger - THIS METHOD WILL NOT BLOCK
Updates the existing trigger log to retry status and creates a new async execution.
Returns immediately after queuing the retry, not after execution completion.
Args:
session: Database session to use for operations
user: User (Account or EndUser) who initiated the retry
workflow_trigger_log_id: ID of the trigger log to re-invoke
Returns:
AsyncTriggerResponse with new execution information (status="queued")
Note: This creates a new trigger log entry for the retry attempt
Raises:
ValueError: If trigger log not found
Behavior:
- Non-blocking: Returns immediately after queuing retry
- Creates new trigger log: Original log marked as retrying, new log for execution
- Preserves original trigger data: Uses same inputs and configuration
"""
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id)
if not trigger_log:
raise ValueError(f"Trigger log not found: {workflow_trigger_log_id}")
# Reconstruct trigger data from log
trigger_data = TriggerData.model_validate_json(trigger_log.trigger_data)
# Reset log for retry
trigger_log.status = WorkflowTriggerStatus.RETRYING
trigger_log.retry_count += 1
trigger_log.error = None
trigger_log.triggered_at = datetime.now(UTC)
trigger_log_repo.update(trigger_log)
session.commit()
# Re-trigger workflow (this will create a new trigger log)
return cls.trigger_workflow_async(session, user, trigger_data)
@classmethod
def get_trigger_log(cls, workflow_trigger_log_id: str, tenant_id: str | None = None) -> dict[str, Any] | None:
"""
Get trigger log by ID
Args:
workflow_trigger_log_id: ID of the trigger log
tenant_id: Optional tenant ID for security check
Returns:
Trigger log as dictionary or None if not found
"""
with Session(db.engine) as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id, tenant_id)
if not trigger_log:
return None
return trigger_log.to_dict()
@classmethod
def get_recent_logs(
cls, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0
) -> list[dict[str, Any]]:
"""
Get recent trigger logs
Args:
tenant_id: Tenant ID
app_id: Application ID
hours: Number of hours to look back
limit: Maximum number of results
offset: Number of results to skip
Returns:
List of trigger logs as dictionaries
"""
with Session(db.engine) as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
logs = trigger_log_repo.get_recent_logs(
tenant_id=tenant_id, app_id=app_id, hours=hours, limit=limit, offset=offset
)
return [log.to_dict() for log in logs]
@classmethod
def get_failed_logs_for_retry(
cls, tenant_id: str, max_retry_count: int = 3, limit: int = 100
) -> list[dict[str, Any]]:
"""
Get failed logs eligible for retry
Args:
tenant_id: Tenant ID
max_retry_count: Maximum retry count
limit: Maximum number of results
Returns:
List of failed trigger logs as dictionaries
"""
with Session(db.engine) as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
logs = trigger_log_repo.get_failed_for_retry(
tenant_id=tenant_id, max_retry_count=max_retry_count, limit=limit
)
return [log.to_dict() for log in logs]
@staticmethod
def _get_workflow(workflow_service: WorkflowService, app_model: App, workflow_id: str | None = None) -> Workflow:
"""
Get workflow for the app
Args:
workflow_service: WorkflowService instance used to look up workflows
app_model: App model instance
workflow_id: Optional specific workflow ID
Returns:
Workflow instance
Raises:
WorkflowNotFoundError: If workflow not found
"""
if workflow_id:
# Get specific published workflow
workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id)
if not workflow:
raise WorkflowNotFoundError(f"Published workflow not found: {workflow_id}")
else:
# Get default published workflow
workflow = workflow_service.get_published_workflow(app_model)
if not workflow:
raise WorkflowNotFoundError(f"No published workflow found for app: {app_model.id}")
return workflow

View File

@@ -0,0 +1,165 @@
import io
import logging
import uuid
from collections.abc import Generator
from flask import Response, stream_with_context
from werkzeug.datastructures import FileStorage
from constants import AUDIO_EXTENSIONS
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from models.enums import MessageStatus
from models.model import App, AppMode, Message
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
ProviderNotSupportSpeechToTextServiceError,
ProviderNotSupportTextToSpeechServiceError,
UnsupportedAudioTypeServiceError,
)
from services.workflow_service import WorkflowService
FILE_SIZE = 30
FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024
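# 30 MB expressed in bytes (30 * 1024 * 1024 = 31,457,280); larger uploads are
# rejected in transcript_asr below.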
logger = logging.getLogger(__name__)
class AudioService:
@classmethod
def transcript_asr(cls, app_model: App, file: FileStorage, end_user: str | None = None):
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
raise ValueError("Speech to text is not enabled")
features_dict = workflow.features_dict
if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"):
raise ValueError("Speech to text is not enabled")
else:
app_model_config = app_model.app_model_config
if not app_model_config:
raise ValueError("Speech to text is not enabled")
if not app_model_config.speech_to_text_dict["enabled"]:
raise ValueError("Speech to text is not enabled")
if file is None:
raise NoAudioUploadedServiceError()
extension = file.mimetype
if extension not in [f"audio/{ext}" for ext in AUDIO_EXTENSIONS]:
raise UnsupportedAudioTypeServiceError()
file_content = file.read()
file_size = len(file_content)
if file_size > FILE_SIZE_LIMIT:
message = f"Audio size larger than {FILE_SIZE} mb"
raise AudioTooLargeServiceError(message)
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT
)
if model_instance is None:
raise ProviderNotSupportSpeechToTextServiceError()
buffer = io.BytesIO(file_content)
buffer.name = "temp.mp3"
return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)}
@classmethod
def transcript_tts(
cls,
app_model: App,
text: str | None = None,
voice: str | None = None,
end_user: str | None = None,
message_id: str | None = None,
is_draft: bool = False,
):
def invoke_tts(text_content: str, app_model: App, voice: str | None = None, is_draft: bool = False):
if voice is None:
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
if is_draft:
workflow = WorkflowService().get_draft_workflow(app_model=app_model)
else:
workflow = app_model.workflow
if (
workflow is None
or "text_to_speech" not in workflow.features_dict
or not workflow.features_dict["text_to_speech"].get("enabled")
):
raise ValueError("TTS is not enabled")
voice = workflow.features_dict["text_to_speech"].get("voice")
else:
if not is_draft:
if app_model.app_model_config is None:
raise ValueError("AppModelConfig not found")
text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
if not text_to_speech_dict.get("enabled"):
raise ValueError("TTS is not enabled")
voice = text_to_speech_dict.get("voice")
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=app_model.tenant_id, model_type=ModelType.TTS
)
try:
if not voice:
voices = model_instance.get_tts_voices()
if voices:
voice = voices[0].get("value")
if not voice:
raise ValueError("Sorry, no voice available.")
else:
raise ValueError("Sorry, no voice available.")
return model_instance.invoke_tts(
content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice
)
except Exception as e:
raise e
if message_id:
try:
uuid.UUID(message_id)
except ValueError:
return None
message = db.session.query(Message).where(Message.id == message_id).first()
if message is None:
return None
if message.answer == "" and message.status == MessageStatus.NORMAL:
return None
else:
response = invoke_tts(text_content=message.answer, app_model=app_model, voice=voice, is_draft=is_draft)
if isinstance(response, Generator):
return Response(stream_with_context(response), content_type="audio/mpeg")
return response
else:
if text is None:
raise ValueError("Text is required")
response = invoke_tts(text_content=text, app_model=app_model, voice=voice, is_draft=is_draft)
if isinstance(response, Generator):
return Response(stream_with_context(response), content_type="audio/mpeg")
return response
@classmethod
def transcript_tts_voices(cls, tenant_id: str, language: str):
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS)
if model_instance is None:
raise ProviderNotSupportTextToSpeechServiceError()
try:
return model_instance.get_tts_voices(language)
except Exception as e:
raise e

View File

View File

@@ -0,0 +1,10 @@
from abc import ABC, abstractmethod
class ApiKeyAuthBase(ABC):
def __init__(self, credentials: dict):
self.credentials = credentials
@abstractmethod
def validate_credentials(self):
raise NotImplementedError

View File

@@ -0,0 +1,29 @@
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.auth_type import AuthType
class ApiKeyAuthFactory:
def __init__(self, provider: str, credentials: dict):
auth_factory = self.get_apikey_auth_factory(provider)
self.auth = auth_factory(credentials)
def validate_credentials(self):
return self.auth.validate_credentials()
@staticmethod
def get_apikey_auth_factory(provider: str) -> type[ApiKeyAuthBase]:
match provider:
case AuthType.FIRECRAWL:
from services.auth.firecrawl.firecrawl import FirecrawlAuth
return FirecrawlAuth
case AuthType.WATERCRAWL:
from services.auth.watercrawl.watercrawl import WatercrawlAuth
return WatercrawlAuth
case AuthType.JINA:
from services.auth.jina.jina import JinaAuth
return JinaAuth
case _:
raise ValueError("Invalid provider")

View File

@@ -0,0 +1,77 @@
import json
from sqlalchemy import select
from core.helper import encrypter
from extensions.ext_database import db
from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
class ApiKeyAuthService:
@staticmethod
def get_provider_auth_list(tenant_id: str):
data_source_api_key_bindings = db.session.scalars(
select(DataSourceApiKeyAuthBinding).where(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)
)
).all()
return data_source_api_key_bindings
@staticmethod
def create_provider_auth(tenant_id: str, args: dict):
auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
if auth_result:
# Encrypt the api key
api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"])
args["credentials"]["config"]["api_key"] = api_key
data_source_api_key_binding = DataSourceApiKeyAuthBinding(
tenant_id=tenant_id, category=args["category"], provider=args["provider"]
)
data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
db.session.add(data_source_api_key_binding)
db.session.commit()
@staticmethod
def get_auth_credentials(tenant_id: str, category: str, provider: str):
data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding)
.where(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
DataSourceApiKeyAuthBinding.category == category,
DataSourceApiKeyAuthBinding.provider == provider,
DataSourceApiKeyAuthBinding.disabled.is_(False),
)
.first()
)
if not data_source_api_key_bindings:
return None
if not data_source_api_key_bindings.credentials:
return None
credentials = json.loads(data_source_api_key_bindings.credentials)
return credentials
@staticmethod
def delete_provider_auth(tenant_id: str, binding_id: str):
data_source_api_key_binding = (
db.session.query(DataSourceApiKeyAuthBinding)
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
.first()
)
if data_source_api_key_binding:
db.session.delete(data_source_api_key_binding)
db.session.commit()
@classmethod
def validate_api_key_auth_args(cls, args):
if "category" not in args or not args["category"]:
raise ValueError("category is required")
if "provider" not in args or not args["provider"]:
raise ValueError("provider is required")
if "credentials" not in args or not args["credentials"]:
raise ValueError("credentials is required")
if not isinstance(args["credentials"], dict):
raise ValueError("credentials must be a dictionary")
if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]:
raise ValueError("auth_type is required")

View File

@@ -0,0 +1,7 @@
from enum import StrEnum
class AuthType(StrEnum):
FIRECRAWL = "firecrawl"
WATERCRAWL = "watercrawl"
JINA = "jinareader"

View File

@@ -0,0 +1,49 @@
import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
class FirecrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":
raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer")
self.api_key = credentials.get("config", {}).get("api_key", None)
self.base_url = credentials.get("config", {}).get("base_url", "https://api.firecrawl.dev")
if not self.api_key:
raise ValueError("No API key provided")
def validate_credentials(self):
headers = self._prepare_headers()
options = {
"url": "https://example.com",
"includePaths": [],
"excludePaths": [],
"limit": 1,
"scrapeOptions": {"onlyMainContent": True},
}
response = self._post_request(f"{self.base_url}/v1/crawl", options, headers)
if response.status_code == 200:
return True
else:
self._handle_error(response)
def _prepare_headers(self):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers):
return httpx.post(url, headers=headers, json=data)
def _handle_error(self, response):
if response.status_code in {402, 409, 500}:
error_message = response.json().get("error", "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
else:
if response.text:
error_message = json.loads(response.text).get("error", "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")

View File

@@ -0,0 +1,44 @@
import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":
raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")
self.api_key = credentials.get("config", {}).get("api_key", None)
if not self.api_key:
raise ValueError("No API key provided")
def validate_credentials(self):
headers = self._prepare_headers()
options = {
"url": "https://example.com",
}
response = self._post_request("https://r.jina.ai", options, headers)
if response.status_code == 200:
return True
else:
self._handle_error(response)
def _prepare_headers(self):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers):
return httpx.post(url, headers=headers, json=data)
def _handle_error(self, response):
if response.status_code in {402, 409, 500}:
error_message = response.json().get("error", "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
else:
if response.text:
error_message = json.loads(response.text).get("error", "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")

View File

View File

@@ -0,0 +1,44 @@
import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":
raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")
self.api_key = credentials.get("config", {}).get("api_key", None)
if not self.api_key:
raise ValueError("No API key provided")
def validate_credentials(self):
headers = self._prepare_headers()
options = {
"url": "https://example.com",
}
response = self._post_request("https://r.jina.ai", options, headers)
if response.status_code == 200:
return True
else:
self._handle_error(response)
def _prepare_headers(self):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers):
return httpx.post(url, headers=headers, json=data)
def _handle_error(self, response):
if response.status_code in {402, 409, 500}:
error_message = response.json().get("error", "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
else:
if response.text:
error_message = json.loads(response.text).get("error", "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")

View File

@@ -0,0 +1,44 @@
import json
from urllib.parse import urljoin
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
class WatercrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "x-api-key":
raise ValueError("Invalid auth type, WaterCrawl auth type must be x-api-key")
self.api_key = credentials.get("config", {}).get("api_key", None)
self.base_url = credentials.get("config", {}).get("base_url", "https://app.watercrawl.dev")
if not self.api_key:
raise ValueError("No API key provided")
def validate_credentials(self):
headers = self._prepare_headers()
url = urljoin(self.base_url, "/api/v1/core/crawl-requests/")
response = self._get_request(url, headers)
if response.status_code == 200:
return True
else:
self._handle_error(response)
def _prepare_headers(self):
return {"Content-Type": "application/json", "X-API-KEY": self.api_key}
def _get_request(self, url, headers):
return httpx.get(url, headers=headers)
def _handle_error(self, response):
if response.status_code in {402, 409, 500}:
error_message = response.json().get("error", "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
else:
if response.text:
error_message = json.loads(response.text).get("error", "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")

View File

@@ -0,0 +1,241 @@
import os
from typing import Literal
import httpx
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
from werkzeug.exceptions import InternalServerError
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.helper import RateLimiter
from models import Account, TenantAccountJoin, TenantAccountRole
class BillingService:
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
compliance_download_rate_limiter = RateLimiter("compliance_download_rate_limiter", 4, 60)
@classmethod
def get_info(cls, tenant_id: str):
params = {"tenant_id": tenant_id}
billing_info = cls._send_request("GET", "/subscription/info", params=params)
return billing_info
@classmethod
def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
params = {"tenant_id": tenant_id}
usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params)
return usage_info
@classmethod
def get_knowledge_rate_limit(cls, tenant_id: str):
params = {"tenant_id": tenant_id}
knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params)
return {
"limit": knowledge_rate_limit.get("limit", 10),
"subscription_plan": knowledge_rate_limit.get("subscription_plan", CloudPlan.SANDBOX),
}
@classmethod
def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""):
params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id}
return cls._send_request("GET", "/subscription/payment-link", params=params)
@classmethod
def get_model_provider_payment_link(cls, provider_name: str, tenant_id: str, account_id: str, prefilled_email: str):
params = {
"provider_name": provider_name,
"tenant_id": tenant_id,
"account_id": account_id,
"prefilled_email": prefilled_email,
}
return cls._send_request("GET", "/model-provider/payment-link", params=params)
@classmethod
def get_invoices(cls, prefilled_email: str = "", tenant_id: str = ""):
params = {"prefilled_email": prefilled_email, "tenant_id": tenant_id}
return cls._send_request("GET", "/invoices", params=params)
@classmethod
def update_tenant_feature_plan_usage(cls, tenant_id: str, feature_key: str, delta: int) -> dict:
"""
Update tenant feature plan usage.
Args:
tenant_id: Tenant identifier
feature_key: Feature key (e.g., 'trigger', 'workflow')
delta: Usage delta (positive to add, negative to consume)
Returns:
Response dict with 'result' and 'history_id'
Example: {"result": "success", "history_id": "uuid"}
"""
return cls._send_request(
"POST",
"/tenant-feature-usage/usage",
params={"tenant_id": tenant_id, "feature_key": feature_key, "delta": delta},
)
@classmethod
def refund_tenant_feature_plan_usage(cls, history_id: str) -> dict:
"""
Refund a previous usage charge.
Args:
history_id: The history_id returned from update_tenant_feature_plan_usage
Returns:
Response dict with 'result' and 'history_id'
"""
return cls._send_request("POST", "/tenant-feature-usage/refund", params={"quota_usage_history_id": history_id})
@classmethod
def get_tenant_feature_plan_usage(cls, tenant_id: str, feature_key: str):
params = {"tenant_id": tenant_id, "feature_key": feature_key}
return cls._send_request("GET", "/billing/tenant_feature_plan/usage", params=params)
@classmethod
@retry(
wait=wait_fixed(2),
stop=stop_before_delay(10),
retry=retry_if_exception_type(httpx.RequestError),
reraise=True,
)
def _send_request(cls, method: Literal["GET", "POST", "DELETE", "PUT"], endpoint: str, json=None, params=None):
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
url = f"{cls.base_url}{endpoint}"
response = httpx.request(method, url, json=json, params=params, headers=headers)
if method == "GET" and response.status_code != httpx.codes.OK:
raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
if method == "PUT":
if response.status_code == httpx.codes.INTERNAL_SERVER_ERROR:
raise InternalServerError(
"Unable to process billing request. Please try again later or contact support."
)
if response.status_code != httpx.codes.OK:
raise ValueError("Invalid arguments.")
if method == "POST" and response.status_code != httpx.codes.OK:
raise ValueError(f"Unable to send request to {url}. Please try again later or contact support.")
return response.json()
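# Note: the retry decorator above re-issues the request every 2 seconds for a
# little under 10 seconds, and only for transport-level httpx.RequestError;
# non-2xx HTTP statuses are handled by the checks in the method body instead.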
@staticmethod
def is_tenant_owner_or_admin(current_user: Account):
tenant_id = current_user.current_tenant_id
join: TenantAccountJoin | None = (
db.session.query(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
.first()
)
if not join:
raise ValueError("Tenant account join not found")
if not TenantAccountRole.is_privileged_role(TenantAccountRole(join.role)):
raise ValueError("Only team owner or team admin can perform this action")
@classmethod
def delete_account(cls, account_id: str):
"""Delete account."""
params = {"account_id": account_id}
return cls._send_request("DELETE", "/account/", params=params)
@classmethod
def is_email_in_freeze(cls, email: str) -> bool:
params = {"email": email}
try:
response = cls._send_request("GET", "/account/in-freeze", params=params)
return bool(response.get("data", False))
except Exception:
return False
@classmethod
def update_account_deletion_feedback(cls, email: str, feedback: str):
"""Update account deletion feedback."""
json = {"email": email, "feedback": feedback}
return cls._send_request("POST", "/account/delete-feedback", json=json)
class EducationIdentity:
verification_rate_limit = RateLimiter(prefix="edu_verification_rate_limit", max_attempts=10, time_window=60)
activation_rate_limit = RateLimiter(prefix="edu_activation_rate_limit", max_attempts=10, time_window=60)
@classmethod
def verify(cls, account_id: str, account_email: str):
if cls.verification_rate_limit.is_rate_limited(account_email):
from controllers.console.error import EducationVerifyLimitError
raise EducationVerifyLimitError()
cls.verification_rate_limit.increment_rate_limit(account_email)
params = {"account_id": account_id}
return BillingService._send_request("GET", "/education/verify", params=params)
@classmethod
def status(cls, account_id: str):
params = {"account_id": account_id}
return BillingService._send_request("GET", "/education/status", params=params)
@classmethod
def activate(cls, account: Account, token: str, institution: str, role: str):
if cls.activation_rate_limit.is_rate_limited(account.email):
from controllers.console.error import EducationActivateLimitError
raise EducationActivateLimitError()
cls.activation_rate_limit.increment_rate_limit(account.email)
params = {"account_id": account.id, "curr_tenant_id": account.current_tenant_id}
json = {
"institution": institution,
"token": token,
"role": role,
}
return BillingService._send_request("POST", "/education/", json=json, params=params)
@classmethod
def autocomplete(cls, keywords: str, page: int = 0, limit: int = 20):
params = {"keywords": keywords, "page": page, "limit": limit}
return BillingService._send_request("GET", "/education/autocomplete", params=params)
@classmethod
def get_compliance_download_link(
cls,
doc_name: str,
account_id: str,
tenant_id: str,
ip: str,
device_info: str,
):
limiter_key = f"{account_id}:{tenant_id}"
if cls.compliance_download_rate_limiter.is_rate_limited(limiter_key):
from controllers.console.error import ComplianceRateLimitError
raise ComplianceRateLimitError()
json = {
"doc_name": doc_name,
"account_id": account_id,
"tenant_id": tenant_id,
"ip_address": ip,
"device_info": device_info,
}
res = cls._send_request("POST", "/compliance/download", json=json)
cls.compliance_download_rate_limiter.increment_rate_limit(limiter_key)
return res
@classmethod
def clean_billing_info_cache(cls, tenant_id: str):
redis_client.delete(f"tenant:{tenant_id}:billing_info")
@classmethod
def sync_partner_tenants_bindings(cls, account_id: str, partner_key: str, click_id: str):
payload = {"account_id": account_id, "click_id": click_id}
return cls._send_request("PUT", f"/partners/{partner_key}/tenants", json=payload)

View File

@@ -0,0 +1,466 @@
import datetime
import json
import logging
import time
from concurrent.futures import ThreadPoolExecutor
import click
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Tenant
from models.model import (
App,
AppAnnotationHitHistory,
Conversation,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from models.workflow import WorkflowAppLog
from repositories.factory import DifyAPIRepositoryFactory
from services.billing_service import BillingService
logger = logging.getLogger(__name__)
class ClearFreePlanTenantExpiredLogs:
@classmethod
def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]):
"""
Clean up message-related tables to avoid data redundancy.
This method cleans up tables that have foreign key relationships with Message.
Args:
session: Database session (the same session used in the process_tenant method)
tenant_id: Tenant ID for logging purposes
batch_message_ids: List of message IDs to clean up
"""
if not batch_message_ids:
return
# Clean up each related table
related_tables = [
(MessageFeedback, "message_feedbacks"),
(MessageFile, "message_files"),
(MessageAnnotation, "message_annotations"),
(MessageChain, "message_chains"),
(MessageAgentThought, "message_agent_thoughts"),
(AppAnnotationHitHistory, "app_annotation_hit_histories"),
(SavedMessage, "saved_messages"),
]
for model, table_name in related_tables:
# Query records related to expired messages
records = (
session.query(model)
.where(
model.message_id.in_(batch_message_ids), # type: ignore
)
.all()
)
if len(records) == 0:
continue
# Save records before deletion
record_ids = [record.id for record in records]
try:
record_data = []
for record in records:
try:
if hasattr(record, "to_dict"):
record_data.append(record.to_dict())
else:
# if the record does not have a to_dict method, convert it to a dict manually
record_dict = {}
for column in record.__table__.columns:
record_dict[column.name] = getattr(record, column.name)
record_data.append(record_dict)
except Exception:
logger.exception("Failed to transform %s record: %s", table_name, record.id)
continue
if record_data:
storage.save(
f"free_plan_tenant_expired_logs/"
f"{tenant_id}/{table_name}/{datetime.datetime.now().strftime('%Y-%m-%d')}"
f"-{time.time()}.json",
json.dumps(
jsonable_encoder(record_data),
).encode("utf-8"),
)
except Exception:
logger.exception("Failed to save %s records", table_name)
session.query(model).where(
model.id.in_(record_ids), # type: ignore
).delete(synchronize_session=False)
click.echo(
click.style(
f"[{datetime.datetime.now()}] Processed {len(record_ids)} "
f"{table_name} records for tenant {tenant_id}"
)
)
@classmethod
def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int):
with flask_app.app_context():
apps = db.session.scalars(select(App).where(App.tenant_id == tenant_id)).all()
app_ids = [app.id for app in apps]
while True:
with Session(db.engine).no_autoflush as session:
messages = (
session.query(Message)
.where(
Message.app_id.in_(app_ids),
Message.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
.limit(batch)
.all()
)
if len(messages) == 0:
break
storage.save(
f"free_plan_tenant_expired_logs/"
f"{tenant_id}/messages/{datetime.datetime.now().strftime('%Y-%m-%d')}"
f"-{time.time()}.json",
json.dumps(
jsonable_encoder(
[message.to_dict() for message in messages],
),
).encode("utf-8"),
)
message_ids = [message.id for message in messages]
# delete messages
session.query(Message).where(
Message.id.in_(message_ids),
).delete(synchronize_session=False)
cls._clear_message_related_tables(session, tenant_id, message_ids)
session.commit()
click.echo(
click.style(
f"[{datetime.datetime.now()}] Processed {len(message_ids)} messages for tenant {tenant_id} "
)
)
while True:
with Session(db.engine).no_autoflush as session:
conversations = (
session.query(Conversation)
.where(
Conversation.app_id.in_(app_ids),
Conversation.updated_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
.limit(batch)
.all()
)
if len(conversations) == 0:
break
storage.save(
f"free_plan_tenant_expired_logs/"
f"{tenant_id}/conversations/{datetime.datetime.now().strftime('%Y-%m-%d')}"
f"-{time.time()}.json",
json.dumps(
jsonable_encoder(
[conversation.to_dict() for conversation in conversations],
),
).encode("utf-8"),
)
conversation_ids = [conversation.id for conversation in conversations]
session.query(Conversation).where(
Conversation.id.in_(conversation_ids),
).delete(synchronize_session=False)
session.commit()
click.echo(
click.style(
f"[{datetime.datetime.now()}] Processed {len(conversation_ids)}"
f" conversations for tenant {tenant_id}"
)
)
# Process expired workflow node executions with backup
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
before_date = datetime.datetime.now() - datetime.timedelta(days=days)
total_deleted = 0
while True:
# Get a batch of expired executions for backup
workflow_node_executions = node_execution_repo.get_expired_executions_batch(
tenant_id=tenant_id,
before_date=before_date,
batch_size=batch,
)
if len(workflow_node_executions) == 0:
break
# Save workflow node executions to storage
storage.save(
f"free_plan_tenant_expired_logs/"
f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}"
f"-{time.time()}.json",
json.dumps(
jsonable_encoder(workflow_node_executions),
).encode("utf-8"),
)
# Extract IDs for deletion
workflow_node_execution_ids = [
workflow_node_execution.id for workflow_node_execution in workflow_node_executions
]
# Delete the backed up executions
deleted_count = node_execution_repo.delete_executions_by_ids(workflow_node_execution_ids)
total_deleted += deleted_count
click.echo(
click.style(
f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}"
f" workflow node executions for tenant {tenant_id}"
)
)
# If we got fewer than the batch size, we're done
if len(workflow_node_executions) < batch:
break
# Process expired workflow runs with backup
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
before_date = datetime.datetime.now() - datetime.timedelta(days=days)
total_deleted = 0
while True:
# Get a batch of expired workflow runs for backup
workflow_runs = workflow_run_repo.get_expired_runs_batch(
tenant_id=tenant_id,
before_date=before_date,
batch_size=batch,
)
if len(workflow_runs) == 0:
break
# Save workflow runs to storage
storage.save(
f"free_plan_tenant_expired_logs/"
f"{tenant_id}/workflow_runs/{datetime.datetime.now().strftime('%Y-%m-%d')}"
f"-{time.time()}.json",
json.dumps(
jsonable_encoder(
[workflow_run.to_dict() for workflow_run in workflow_runs],
),
).encode("utf-8"),
)
# Extract IDs for deletion
workflow_run_ids = [workflow_run.id for workflow_run in workflow_runs]
# Delete the backed up workflow runs
deleted_count = workflow_run_repo.delete_runs_by_ids(workflow_run_ids)
total_deleted += deleted_count
click.echo(
click.style(
f"[{datetime.datetime.now()}] Processed {len(workflow_run_ids)}"
f" workflow runs for tenant {tenant_id}"
)
)
# If we got fewer than the batch size, we're done
if len(workflow_runs) < batch:
break
while True:
with Session(db.engine).no_autoflush as session:
workflow_app_logs = (
session.query(WorkflowAppLog)
.where(
WorkflowAppLog.tenant_id == tenant_id,
WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
.limit(batch)
.all()
)
if len(workflow_app_logs) == 0:
break
# save workflow app logs
storage.save(
f"free_plan_tenant_expired_logs/"
f"{tenant_id}/workflow_app_logs/{datetime.datetime.now().strftime('%Y-%m-%d')}"
f"-{time.time()}.json",
json.dumps(
jsonable_encoder(
[workflow_app_log.to_dict() for workflow_app_log in workflow_app_logs],
),
).encode("utf-8"),
)
workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs]
# delete workflow app logs
session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete(
synchronize_session=False
)
session.commit()
click.echo(
click.style(
f"[{datetime.datetime.now()}] Processed {len(workflow_app_log_ids)}"
f" workflow app logs for tenant {tenant_id}"
)
)
@classmethod
def process(cls, days: int, batch: int, tenant_ids: list[str]):
"""
Clear free plan tenant expired logs.
"""
click.echo(click.style("Clearing free plan tenant expired logs", fg="white"))
ended_at = datetime.datetime.now()
started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
current_time = started_at
with Session(db.engine) as session:
total_tenant_count = session.query(Tenant.id).count()
click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
handled_tenant_count = 0
thread_pool = ThreadPoolExecutor(max_workers=10)
def process_tenant(flask_app: Flask, tenant_id: str):
try:
if (
not dify_config.BILLING_ENABLED
or BillingService.get_info(tenant_id)["subscription"]["plan"] == CloudPlan.SANDBOX
):
# only process sandbox tenants
cls.process_tenant(flask_app, tenant_id, days, batch)
except Exception:
logger.exception("Failed to process tenant %s", tenant_id)
finally:
nonlocal handled_tenant_count
handled_tenant_count += 1
if handled_tenant_count % 100 == 0:
click.echo(
click.style(
f"[{datetime.datetime.now()}] "
f"Processed {handled_tenant_count} tenants "
f"({(handled_tenant_count / total_tenant_count) * 100:.1f}%), "
f"{handled_tenant_count}/{total_tenant_count}",
fg="green",
)
)
futures = []
if tenant_ids:
for tenant_id in tenant_ids:
futures.append(
thread_pool.submit(
process_tenant,
current_app._get_current_object(), # type: ignore[attr-defined]
tenant_id,
)
)
else:
while current_time < ended_at:
click.echo(
click.style(f"Current time: {current_time}, Started at: {datetime.datetime.now()}", fg="white")
)
# Initial interval of 1 day, will be dynamically adjusted based on tenant count
interval = datetime.timedelta(days=1)
# Process tenants in this batch
with Session(db.engine) as session:
# Calculate tenant count in next batch with current interval
# Try different intervals until we find one with a reasonable tenant count
test_intervals = [
datetime.timedelta(days=1),
datetime.timedelta(hours=12),
datetime.timedelta(hours=6),
datetime.timedelta(hours=3),
datetime.timedelta(hours=1),
]
tenant_count = 0
for test_interval in test_intervals:
tenant_count = (
session.query(Tenant.id)
.where(Tenant.created_at.between(current_time, current_time + test_interval))
.count()
)
if tenant_count <= 100:
interval = test_interval
break
else:
# If all intervals have too many tenants, use minimum interval
interval = datetime.timedelta(hours=1)
# Adjust interval to target ~100 tenants per batch
if tenant_count > 0:
# Scale interval based on ratio to target count
interval = min(
datetime.timedelta(days=1), # Max 1 day
max(
datetime.timedelta(hours=1), # Min 1 hour
interval * (100 / tenant_count), # Scale to target 100
),
)
batch_end = min(current_time + interval, ended_at)
rs = (
session.query(Tenant.id)
.where(Tenant.created_at.between(current_time, batch_end))
.order_by(Tenant.created_at)
)
tenants = []
for row in rs:
tenant_id = str(row.id)
try:
tenants.append(tenant_id)
except Exception:
logger.exception("Failed to process tenant %s", tenant_id)
continue
futures.append(
thread_pool.submit(
process_tenant,
current_app._get_current_object(), # type: ignore[attr-defined]
tenant_id,
)
)
current_time = batch_end
# wait for all threads to finish
for future in futures:
future.result()
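# Illustrative sketch (not part of the original file): the batching loop above sizes each
# time window so that roughly 100 tenants fall into it, clamped to between one hour and
# one day. The helper below restates that arithmetic in isolation for clarity; the name is
# hypothetical and nothing beyond this file's existing imports is assumed.
def _example_scale_interval(tenant_count: int, probe: datetime.timedelta) -> datetime.timedelta:
    """Scale a probe interval toward ~100 tenants per batch, clamped to [1 hour, 1 day]."""
    if tenant_count <= 0:
        return probe
    return min(
        datetime.timedelta(days=1),  # never widen past one day
        max(
            datetime.timedelta(hours=1),  # never shrink below one hour
            probe * (100 / tenant_count),  # scale toward ~100 tenants per window
        ),
    )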

View File

@@ -0,0 +1,16 @@
from extensions.ext_code_based_extension import code_based_extension
class CodeBasedExtensionService:
@staticmethod
def get_code_based_extension(module: str):
module_extensions = code_based_extension.module_extensions(module)
return [
{
"name": module_extension.name,
"label": module_extension.label,
"form_schema": module_extension.form_schema,
}
for module_extension in module_extensions
if not module_extension.builtin
]
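# Illustrative usage sketch (not part of the original file): list the non-builtin extensions
# registered for a module. The module name "external_data_tool" is an assumption for
# illustration; any module registered with code_based_extension works the same way.
def _example_list_extensions() -> list[dict]:
    return CodeBasedExtensionService.get_code_based_extension("external_data_tool")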

View File

@@ -0,0 +1,326 @@
import contextlib
import logging
from collections.abc import Callable, Sequence
from typing import Any, Union
from sqlalchemy import asc, desc, func, or_, select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.llm_generator import LLMGenerator
from core.variables.types import SegmentType
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
from extensions.ext_database import db
from factories import variable_factory
from libs.datetime_utils import naive_utc_now
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account, ConversationVariable
from models.model import App, Conversation, EndUser, Message
from services.errors.conversation import (
ConversationNotExistsError,
ConversationVariableNotExistsError,
ConversationVariableTypeMismatchError,
LastConversationNotExistsError,
)
from services.errors.message import MessageNotExistsError
from tasks.delete_conversation_task import delete_conversation_related_data
logger = logging.getLogger(__name__)
class ConversationService:
@classmethod
def pagination_by_last_id(
cls,
*,
session: Session,
app_model: App,
user: Union[Account, EndUser] | None,
last_id: str | None,
limit: int,
invoke_from: InvokeFrom,
include_ids: Sequence[str] | None = None,
exclude_ids: Sequence[str] | None = None,
sort_by: str = "-updated_at",
) -> InfiniteScrollPagination:
if not user:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
stmt = select(Conversation).where(
Conversation.is_deleted == False,
Conversation.app_id == app_model.id,
Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value),
)
# Check if include_ids is not None to apply filter
if include_ids is not None:
if len(include_ids) == 0:
# If include_ids is empty, return empty result
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
stmt = stmt.where(Conversation.id.in_(include_ids))
# Check if exclude_ids is not None to apply filter
if exclude_ids is not None:
if len(exclude_ids) > 0:
stmt = stmt.where(~Conversation.id.in_(exclude_ids))
# define sort fields and directions
sort_field, sort_direction = cls._get_sort_params(sort_by)
if last_id:
last_conversation = session.scalar(stmt.where(Conversation.id == last_id))
if not last_conversation:
raise LastConversationNotExistsError()
# build filters based on sorting
filter_condition = cls._build_filter_condition(
sort_field=sort_field,
sort_direction=sort_direction,
reference_conversation=last_conversation,
)
stmt = stmt.where(filter_condition)
query_stmt = stmt.order_by(sort_direction(getattr(Conversation, sort_field))).limit(limit)
conversations = session.scalars(query_stmt).all()
has_more = False
if len(conversations) == limit:
current_page_last_conversation = conversations[-1]
rest_filter_condition = cls._build_filter_condition(
sort_field=sort_field,
sort_direction=sort_direction,
reference_conversation=current_page_last_conversation,
)
count_stmt = select(func.count()).select_from(stmt.where(rest_filter_condition).subquery())
rest_count = session.scalar(count_stmt) or 0
if rest_count > 0:
has_more = True
return InfiniteScrollPagination(data=conversations, limit=limit, has_more=has_more)
@classmethod
def _get_sort_params(cls, sort_by: str):
if sort_by.startswith("-"):
return sort_by[1:], desc
return sort_by, asc
@classmethod
def _build_filter_condition(cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation):
field_value = getattr(reference_conversation, sort_field)
if sort_direction is desc:
return getattr(Conversation, sort_field) < field_value
return getattr(Conversation, sort_field) > field_value
@classmethod
def rename(
cls,
app_model: App,
conversation_id: str,
user: Union[Account, EndUser] | None,
name: str,
auto_generate: bool,
):
conversation = cls.get_conversation(app_model, conversation_id, user)
if auto_generate:
return cls.auto_generate_name(app_model, conversation)
else:
conversation.name = name
conversation.updated_at = naive_utc_now()
db.session.commit()
return conversation
@classmethod
def auto_generate_name(cls, app_model: App, conversation: Conversation):
# get conversation first message
message = (
db.session.query(Message)
.where(Message.app_id == app_model.id, Message.conversation_id == conversation.id)
.order_by(Message.created_at.asc())
.first()
)
if not message:
raise MessageNotExistsError()
# generate conversation name
with contextlib.suppress(Exception):
name = LLMGenerator.generate_conversation_name(
app_model.tenant_id, message.query, conversation.id, app_model.id
)
conversation.name = name
db.session.commit()
return conversation
@classmethod
def get_conversation(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
conversation = (
db.session.query(Conversation)
.where(
Conversation.id == conversation_id,
Conversation.app_id == app_model.id,
Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
Conversation.is_deleted == False,
)
.first()
)
if not conversation:
raise ConversationNotExistsError()
return conversation
@classmethod
def delete(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
try:
logger.info(
"Initiating conversation deletion for app_name %s, conversation_id: %s",
app_model.name,
conversation_id,
)
db.session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)
db.session.commit()
delete_conversation_related_data.delay(conversation_id)
except Exception as e:
db.session.rollback()
raise e
@classmethod
def get_conversational_variable(
cls,
app_model: App,
conversation_id: str,
user: Union[Account, EndUser] | None,
limit: int,
last_id: str | None,
) -> InfiniteScrollPagination:
conversation = cls.get_conversation(app_model, conversation_id, user)
stmt = (
select(ConversationVariable)
.where(ConversationVariable.app_id == app_model.id)
.where(ConversationVariable.conversation_id == conversation.id)
.order_by(ConversationVariable.created_at)
)
with Session(db.engine) as session:
if last_id:
last_variable = session.scalar(stmt.where(ConversationVariable.id == last_id))
if not last_variable:
raise ConversationVariableNotExistsError()
# Filter for variables created after the last_id
stmt = stmt.where(ConversationVariable.created_at > last_variable.created_at)
# Apply limit to query: fetch one extra row to determine has_more
query_stmt = stmt.limit(limit + 1)
rows = session.scalars(query_stmt).all()
has_more = False
if len(rows) > limit:
has_more = True
rows = rows[:limit] # Remove the extra item
variables = [
{
"created_at": row.created_at,
"updated_at": row.updated_at,
**row.to_variable().model_dump(),
}
for row in rows
]
return InfiniteScrollPagination(variables, limit, has_more)
@classmethod
def update_conversation_variable(
cls,
app_model: App,
conversation_id: str,
variable_id: str,
user: Union[Account, EndUser] | None,
new_value: Any,
):
"""
Update a conversation variable's value.
Args:
app_model: The app model
conversation_id: The conversation ID
variable_id: The variable ID to update
user: The user (Account or EndUser)
new_value: The new value for the variable
Returns:
Dictionary containing the updated variable information
Raises:
ConversationNotExistsError: If the conversation doesn't exist
ConversationVariableNotExistsError: If the variable doesn't exist
ConversationVariableTypeMismatchError: If the new value type doesn't match the variable's expected type
"""
# Verify conversation exists and user has access
conversation = cls.get_conversation(app_model, conversation_id, user)
# Get the existing conversation variable
stmt = (
select(ConversationVariable)
.where(ConversationVariable.app_id == app_model.id)
.where(ConversationVariable.conversation_id == conversation.id)
.where(ConversationVariable.id == variable_id)
)
with Session(db.engine) as session:
existing_variable = session.scalar(stmt)
if not existing_variable:
raise ConversationVariableNotExistsError()
# Convert existing variable to Variable object
current_variable = existing_variable.to_variable()
# Validate that the new value type matches the expected variable type
expected_type = SegmentType(current_variable.value_type)
# The web UI exposes this type as 'number' while the DB stores 'integer', so treat INTEGER as NUMBER for validation
if expected_type == SegmentType.INTEGER:
expected_type = SegmentType.NUMBER
if not expected_type.is_valid(new_value):
inferred_type = SegmentType.infer_segment_type(new_value)
raise ConversationVariableTypeMismatchError(
f"Type mismatch: variable '{current_variable.name}' expects {expected_type.value}, "
f"but got {inferred_type.value if inferred_type else 'unknown'} type"
)
# Create updated variable with new value only, preserving everything else
updated_variable_dict = {
"id": current_variable.id,
"name": current_variable.name,
"description": current_variable.description,
"value_type": current_variable.value_type,
"value": new_value,
"selector": current_variable.selector,
}
updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict)
# Use the conversation variable updater to persist the changes
updater = conversation_variable_updater_factory()
updater.update(conversation_id, updated_variable)
updater.flush()
# Return the updated variable data
return {
"created_at": existing_variable.created_at,
"updated_at": naive_utc_now(), # Update timestamp
**updated_variable.model_dump(),
}
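# Illustrative usage sketch (not part of the original file): page through a caller's
# conversations 20 at a time via the service API. The App/Account instances are assumed to be
# loaded by the caller; everything else uses names imported above.
def _example_paginate(app_model: App, user: Account, last_id: str | None) -> InfiniteScrollPagination:
    with Session(db.engine) as session:
        return ConversationService.pagination_by_last_id(
            session=session,
            app_model=app_model,
            user=user,
            last_id=last_id,
            limit=20,
            invoke_from=InvokeFrom.SERVICE_API,
            sort_by="-updated_at",
        )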

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,990 @@
import logging
import time
from collections.abc import Mapping
from typing import Any
from sqlalchemy.orm import Session
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper import encrypter
from core.helper.name_generator import generate_incremental_name
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.model_runtime.entities.provider_entities import FormType
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.datasource import PluginDatasourceManager
from core.plugin.impl.oauth import OAuthHandler
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
from models.provider_ids import DatasourceProviderID
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__)
def get_current_user():
from libs.login import current_user
from models.account import Account
from models.model import EndUser
if not isinstance(current_user._get_current_object(), (Account, EndUser)): # type: ignore
raise TypeError(f"current_user must be Account or EndUser, got {type(current_user).__name__}")
return current_user
class DatasourceProviderService:
"""
Datasource Provider Service
"""
def __init__(self) -> None:
self.provider_manager = PluginDatasourceManager()
def remove_oauth_custom_client_params(self, tenant_id: str, datasource_provider_id: DatasourceProviderID):
"""
remove oauth custom client params
"""
with Session(db.engine) as session:
session.query(DatasourceOauthTenantParamConfig).filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
).delete()
session.commit()
def decrypt_datasource_provider_credentials(
self,
tenant_id: str,
datasource_provider: DatasourceProvider,
plugin_id: str,
provider: str,
) -> dict[str, Any]:
encrypted_credentials = datasource_provider.encrypted_credentials
credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=CredentialType.of(datasource_provider.auth_type),
)
decrypted_credentials = encrypted_credentials.copy()
for key, value in decrypted_credentials.items():
if key in credential_secret_variables:
decrypted_credentials[key] = encrypter.decrypt_token(tenant_id, value)
return decrypted_credentials
def encrypt_datasource_provider_credentials(
self,
tenant_id: str,
provider: str,
plugin_id: str,
raw_credentials: Mapping[str, Any],
datasource_provider: DatasourceProvider,
) -> dict[str, Any]:
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=CredentialType.of(datasource_provider.auth_type),
)
encrypted_credentials = dict(raw_credentials)
for key, value in encrypted_credentials.items():
if key in provider_credential_secret_variables:
encrypted_credentials[key] = encrypter.encrypt_token(tenant_id, value)
return encrypted_credentials
def get_datasource_credentials(
self,
tenant_id: str,
provider: str,
plugin_id: str,
credential_id: str | None = None,
) -> dict[str, Any]:
"""
get credential by id
"""
with Session(db.engine) as session:
if credential_id:
datasource_provider = (
session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
)
else:
datasource_provider = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
.first()
)
if not datasource_provider:
return {}
# refresh the credentials
if datasource_provider.expires_at != -1 and (datasource_provider.expires_at - 60) < int(time.time()):
current_user = get_current_user()
decrypted_credentials = self.decrypt_datasource_provider_credentials(
tenant_id=tenant_id,
datasource_provider=datasource_provider,
plugin_id=plugin_id,
provider=provider,
)
datasource_provider_id = DatasourceProviderID(f"{plugin_id}/{provider}")
provider_name = datasource_provider_id.provider_name
redirect_uri = (
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
f"{datasource_provider_id}/datasource/callback"
)
system_credentials = self.get_oauth_client(tenant_id, datasource_provider_id)
refreshed_credentials = OAuthHandler().refresh_credentials(
tenant_id=tenant_id,
user_id=current_user.id,
plugin_id=datasource_provider_id.plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=decrypted_credentials,
)
datasource_provider.encrypted_credentials = self.encrypt_datasource_provider_credentials(
tenant_id=tenant_id,
raw_credentials=refreshed_credentials.credentials,
provider=provider,
plugin_id=plugin_id,
datasource_provider=datasource_provider,
)
datasource_provider.expires_at = refreshed_credentials.expires_at
session.commit()
return self.decrypt_datasource_provider_credentials(
tenant_id=tenant_id,
datasource_provider=datasource_provider,
plugin_id=plugin_id,
provider=provider,
)
def get_all_datasource_credentials_by_provider(
self,
tenant_id: str,
provider: str,
plugin_id: str,
) -> list[dict[str, Any]]:
"""
get all datasource credentials by provider
"""
with Session(db.engine) as session:
datasource_providers = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
.all()
)
if not datasource_providers:
return []
current_user = get_current_user()
# refresh the credentials
real_credentials_list = []
for datasource_provider in datasource_providers:
decrypted_credentials = self.decrypt_datasource_provider_credentials(
tenant_id=tenant_id,
datasource_provider=datasource_provider,
plugin_id=plugin_id,
provider=provider,
)
datasource_provider_id = DatasourceProviderID(f"{plugin_id}/{provider}")
provider_name = datasource_provider_id.provider_name
redirect_uri = (
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
f"{datasource_provider_id}/datasource/callback"
)
system_credentials = self.get_oauth_client(tenant_id, datasource_provider_id)
refreshed_credentials = OAuthHandler().refresh_credentials(
tenant_id=tenant_id,
user_id=current_user.id,
plugin_id=datasource_provider_id.plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=decrypted_credentials,
)
datasource_provider.encrypted_credentials = self.encrypt_datasource_provider_credentials(
tenant_id=tenant_id,
raw_credentials=refreshed_credentials.credentials,
provider=provider,
plugin_id=plugin_id,
datasource_provider=datasource_provider,
)
datasource_provider.expires_at = refreshed_credentials.expires_at
real_credentials = self.decrypt_datasource_provider_credentials(
tenant_id=tenant_id,
datasource_provider=datasource_provider,
plugin_id=plugin_id,
provider=provider,
)
real_credentials_list.append(real_credentials)
session.commit()
return real_credentials_list
def update_datasource_provider_name(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, name: str, credential_id: str
):
"""
update datasource provider name
"""
with Session(db.engine) as session:
target_provider = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
id=credential_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
.first()
)
if target_provider is None:
raise ValueError("provider not found")
if target_provider.name == name:
return
# check name is exist
if (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
name=name,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
.count()
> 0
):
raise ValueError("Authorization name is already exists")
target_provider.name = name
session.commit()
return
def set_default_datasource_provider(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, credential_id: str
):
"""
set default datasource provider
"""
with Session(db.engine) as session:
# get provider
target_provider = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
id=credential_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
.first()
)
if target_provider is None:
raise ValueError("provider not found")
# clear default provider
session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=target_provider.provider,
plugin_id=target_provider.plugin_id,
is_default=True,
).update({"is_default": False})
# set new default provider
target_provider.is_default = True
session.commit()
return {"result": "success"}
def setup_oauth_custom_client_params(
self,
tenant_id: str,
datasource_provider_id: DatasourceProviderID,
client_params: dict | None,
enabled: bool | None,
):
"""
setup oauth custom client params
"""
if client_params is None and enabled is None:
return
with Session(db.engine) as session:
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
.first()
)
if not tenant_oauth_client_params:
tenant_oauth_client_params = DatasourceOauthTenantParamConfig(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
client_params={},
enabled=False,
)
session.add(tenant_oauth_client_params)
if client_params is not None:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
original_params = (
encrypter.decrypt(tenant_oauth_client_params.client_params) if tenant_oauth_client_params else {}
)
new_params: dict = {
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
for key, value in client_params.items()
}
tenant_oauth_client_params.client_params = dict(encrypter.encrypt(new_params))
if enabled is not None:
tenant_oauth_client_params.enabled = enabled
session.commit()
def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool:
"""
check if system oauth params exist
"""
with Session(db.engine).no_autoflush as session:
return (
session.query(DatasourceOauthParamConfig)
.filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id)
.first()
is not None
)
def is_tenant_oauth_params_enabled(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> bool:
"""
check if tenant oauth params are enabled
"""
return (
db.session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
enabled=True,
)
.count()
> 0
)
def get_tenant_oauth_client(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False
) -> Mapping[str, Any] | None:
"""
get tenant oauth client
"""
tenant_oauth_client_params = (
db.session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
.first()
)
if tenant_oauth_client_params:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
if mask:
return encrypter.mask_plugin_credentials(encrypter.decrypt(tenant_oauth_client_params.client_params))
else:
return encrypter.decrypt(tenant_oauth_client_params.client_params)
return None
def get_oauth_encrypter(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
"""
get oauth encrypter
"""
datasource_provider = self.provider_manager.fetch_datasource_provider(
tenant_id=tenant_id, provider_id=str(datasource_provider_id)
)
if not datasource_provider.declaration.oauth_schema:
raise ValueError("Datasource provider oauth schema not found")
client_schema = datasource_provider.declaration.oauth_schema.client_schema
return create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in client_schema],
cache=NoOpProviderCredentialCache(),
)
def get_oauth_client(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> dict[str, Any] | None:
"""
get oauth client
"""
provider = datasource_provider_id.provider_name
plugin_id = datasource_provider_id.plugin_id
with Session(db.engine).no_autoflush as session:
# get tenant oauth client params
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
enabled=True,
)
.first()
)
if tenant_oauth_client_params:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
return dict(encrypter.decrypt(tenant_oauth_client_params.client_params))
provider_controller = self.provider_manager.fetch_datasource_provider(
tenant_id=tenant_id, provider_id=str(datasource_provider_id)
)
is_verified = PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
if is_verified:
# fallback to system oauth client params
oauth_client_params = (
session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
)
if oauth_client_params:
return oauth_client_params.system_credentials
raise ValueError(f"Please configure oauth client params(system/tenant) for {plugin_id}/{provider}")
@staticmethod
def generate_next_datasource_provider_name(
session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
) -> str:
db_providers = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
)
.all()
)
return generate_incremental_name(
[provider.name for provider in db_providers],
f"{credential_type.get_name()}",
)
def reauthorize_datasource_oauth_provider(
self,
name: str | None,
tenant_id: str,
provider_id: DatasourceProviderID,
avatar_url: str | None,
expire_at: int,
credentials: dict,
credential_id: str,
) -> None:
"""
update datasource oauth provider
"""
with Session(db.engine) as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}"
with redis_client.lock(lock, timeout=20):
target_provider = (
session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first()
)
if target_provider is None:
raise ValueError("provider not found")
db_provider_name = name
if not db_provider_name:
db_provider_name = target_provider.name
else:
name_conflict = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=CredentialType.OAUTH2.value,
)
.count()
)
if name_conflict > 0:
db_provider_name = generate_incremental_name(
[
provider.name
for provider in session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
)
],
db_provider_name,
)
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.OAUTH2
)
for key, value in credentials.items():
if key in provider_credential_secret_variables:
credentials[key] = encrypter.encrypt_token(tenant_id, value)
target_provider.expires_at = expire_at
target_provider.encrypted_credentials = credentials
target_provider.avatar_url = avatar_url or target_provider.avatar_url
session.commit()
def add_datasource_oauth_provider(
self,
name: str | None,
tenant_id: str,
provider_id: DatasourceProviderID,
avatar_url: str | None,
expire_at: int,
credentials: dict,
) -> None:
"""
add datasource oauth provider
"""
credential_type = CredentialType.OAUTH2
with Session(db.engine) as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
with redis_client.lock(lock, timeout=60):
db_provider_name = name
if not db_provider_name:
db_provider_name = self.generate_next_datasource_provider_name(
session=session,
tenant_id=tenant_id,
provider_id=provider_id,
credential_type=credential_type,
)
else:
if (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=credential_type.value,
)
.count()
> 0
):
db_provider_name = generate_incremental_name(
[
provider.name
for provider in session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
)
],
db_provider_name,
)
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=credential_type
)
for key, value in credentials.items():
if key in provider_credential_secret_variables:
credentials[key] = encrypter.encrypt_token(tenant_id, value)
datasource_provider = DatasourceProvider(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=credential_type.value,
encrypted_credentials=credentials,
avatar_url=avatar_url or "default",
expires_at=expire_at,
)
session.add(datasource_provider)
session.commit()
def add_datasource_api_key_provider(
self,
name: str | None,
tenant_id: str,
provider_id: DatasourceProviderID,
credentials: dict,
) -> None:
"""
Add an API key datasource provider after validating its credentials.
:param name: optional credential name (auto-generated when omitted)
:param tenant_id: workspace id
:param provider_id: datasource provider id (plugin_id/provider)
:param credentials: credentials to validate and store
"""
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
with Session(db.engine) as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
with redis_client.lock(lock, timeout=20):
db_provider_name = name or self.generate_next_datasource_provider_name(
session=session,
tenant_id=tenant_id,
provider_id=provider_id,
credential_type=CredentialType.API_KEY,
)
# check name is exist
if (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name, name=db_provider_name)
.count()
> 0
):
raise ValueError("Authorization name is already exists")
try:
current_user = get_current_user()
self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id,
user_id=current_user.id,
provider=provider_name,
plugin_id=plugin_id,
credentials=credentials,
)
except Exception as e:
raise ValueError(f"Failed to validate credentials: {str(e)}")
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.API_KEY
)
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if [__HIDDEN__] is sent for a secret input, it is kept the same as the original value
credentials[key] = encrypter.encrypt_token(tenant_id, value)
datasource_provider = DatasourceProvider(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_name,
plugin_id=plugin_id,
auth_type=CredentialType.API_KEY,
encrypted_credentials=credentials,
)
session.add(datasource_provider)
session.commit()
def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: CredentialType) -> list[str]:
"""
Extract secret input form variables.
:param tenant_id: workspace id
:param provider_id: datasource provider id (plugin_id/provider)
:param credential_type: credential type (api-key or oauth2)
:return: names of secret-input form variables
"""
datasource_provider = self.provider_manager.fetch_datasource_provider(
tenant_id=tenant_id, provider_id=provider_id
)
credential_form_schemas = []
if credential_type == CredentialType.API_KEY:
credential_form_schemas = list(datasource_provider.declaration.credentials_schema)
elif credential_type == CredentialType.OAUTH2:
if not datasource_provider.declaration.oauth_schema:
raise ValueError("Datasource provider oauth schema not found")
credential_form_schemas = list(datasource_provider.declaration.oauth_schema.credentials_schema)
else:
raise ValueError(f"Invalid credential type: {credential_type}")
secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type.value == FormType.SECRET_INPUT:
secret_input_form_variables.append(credential_form_schema.name)
return secret_input_form_variables
def list_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
"""
list datasource credentials with obfuscated sensitive fields.
:param tenant_id: workspace id
:param provider: provider name
:param plugin_id: plugin id
:return:
"""
# Get all provider configurations of the current workspace
datasource_providers: list[DatasourceProvider] = (
db.session.query(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.all()
)
if not datasource_providers:
return []
copy_credentials_list = []
default_provider = (
db.session.query(DatasourceProvider.id)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
.first()
)
default_provider_id = default_provider.id if default_provider else None
for datasource_provider in datasource_providers:
encrypted_credentials = datasource_provider.encrypted_credentials
# Get provider credential secret variables
credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=CredentialType.of(datasource_provider.auth_type),
)
# Obfuscate provider credentials
copy_credentials = encrypted_credentials.copy()
for key, value in copy_credentials.items():
if key in credential_secret_variables:
copy_credentials[key] = encrypter.obfuscated_token(value)
copy_credentials_list.append(
{
"credential": copy_credentials,
"type": datasource_provider.auth_type,
"name": datasource_provider.name,
"avatar_url": datasource_provider.avatar_url,
"id": datasource_provider.id,
"is_default": default_provider_id and datasource_provider.id == default_provider_id,
}
)
return copy_credentials_list
def get_all_datasource_credentials(self, tenant_id: str) -> list[dict]:
"""
get credentials for all installed datasource providers.
:return:
"""
# get all plugin providers
manager = PluginDatasourceManager()
datasources = manager.fetch_installed_datasource_providers(tenant_id)
datasource_credentials = []
for datasource in datasources:
datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
credentials = self.list_datasource_credentials(
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
)
redirect_uri = (
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
)
datasource_credentials.append(
{
"provider": datasource.provider,
"plugin_id": datasource.plugin_id,
"plugin_unique_identifier": datasource.plugin_unique_identifier,
"icon": datasource.declaration.identity.icon,
"name": datasource.declaration.identity.name.split("/")[-1],
"label": datasource.declaration.identity.label.model_dump(),
"description": datasource.declaration.identity.description.model_dump(),
"author": datasource.declaration.identity.author,
"credentials_list": credentials,
"credential_schema": [
credential.model_dump() for credential in datasource.declaration.credentials_schema
],
"oauth_schema": {
"client_schema": [
client_schema.model_dump()
for client_schema in datasource.declaration.oauth_schema.client_schema
],
"credentials_schema": [
credential_schema.model_dump()
for credential_schema in datasource.declaration.oauth_schema.credentials_schema
],
"oauth_custom_client_params": self.get_tenant_oauth_client(
tenant_id, datasource_provider_id, mask=True
),
"is_oauth_custom_client_enabled": self.is_tenant_oauth_params_enabled(
tenant_id, datasource_provider_id
),
"is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
"redirect_uri": redirect_uri,
}
if datasource.declaration.oauth_schema
else None,
}
)
return datasource_credentials
def get_hard_code_datasource_credentials(self, tenant_id: str) -> list[dict]:
"""
get credentials for a hard-coded set of datasource plugins.
:return:
"""
# get all plugin providers
manager = PluginDatasourceManager()
datasources = manager.fetch_installed_datasource_providers(tenant_id)
datasource_credentials = []
for datasource in datasources:
if datasource.plugin_id in [
"langgenius/firecrawl_datasource",
"langgenius/notion_datasource",
"langgenius/jina_datasource",
]:
datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
credentials = self.list_datasource_credentials(
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
)
redirect_uri = "{}/console/api/oauth/plugin/{}/datasource/callback".format(
dify_config.CONSOLE_API_URL, datasource_provider_id
)
datasource_credentials.append(
{
"provider": datasource.provider,
"plugin_id": datasource.plugin_id,
"plugin_unique_identifier": datasource.plugin_unique_identifier,
"icon": datasource.declaration.identity.icon,
"name": datasource.declaration.identity.name.split("/")[-1],
"label": datasource.declaration.identity.label.model_dump(),
"description": datasource.declaration.identity.description.model_dump(),
"author": datasource.declaration.identity.author,
"credentials_list": credentials,
"credential_schema": [
credential.model_dump() for credential in datasource.declaration.credentials_schema
],
"oauth_schema": {
"client_schema": [
client_schema.model_dump()
for client_schema in datasource.declaration.oauth_schema.client_schema
],
"credentials_schema": [
credential_schema.model_dump()
for credential_schema in datasource.declaration.oauth_schema.credentials_schema
],
"oauth_custom_client_params": self.get_tenant_oauth_client(
tenant_id, datasource_provider_id, mask=True
),
"is_oauth_custom_client_enabled": self.is_tenant_oauth_params_enabled(
tenant_id, datasource_provider_id
),
"is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
"redirect_uri": redirect_uri,
}
if datasource.declaration.oauth_schema
else None,
}
)
return datasource_credentials
def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
"""
get decrypted datasource credentials.
:param tenant_id: workspace id
:param provider: provider name
:param plugin_id: plugin id
:return:
"""
# Get all provider configurations of the current workspace
datasource_providers: list[DatasourceProvider] = (
db.session.query(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.all()
)
if not datasource_providers:
return []
copy_credentials_list = []
for datasource_provider in datasource_providers:
encrypted_credentials = datasource_provider.encrypted_credentials
# Get provider credential secret variables
credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=CredentialType.of(datasource_provider.auth_type),
)
# Decrypt provider credentials
copy_credentials = encrypted_credentials.copy()
for key, value in copy_credentials.items():
if key in credential_secret_variables:
copy_credentials[key] = encrypter.decrypt_token(tenant_id, value)
copy_credentials_list.append(
{
"credentials": copy_credentials,
"type": datasource_provider.auth_type,
}
)
return copy_credentials_list
def update_datasource_credentials(
self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict | None, name: str | None
) -> None:
"""
update datasource credentials.
"""
with Session(db.engine) as session:
datasource_provider = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
.first()
)
if not datasource_provider:
raise ValueError("Datasource provider not found")
# update name
if name and name != datasource_provider.name:
if (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id)
.count()
> 0
):
raise ValueError("Authorization name is already exists")
datasource_provider.name = name
# update credentials
if credentials:
secret_variables = self.extract_secret_variables(
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=CredentialType.of(datasource_provider.auth_type),
)
original_credentials = {
key: value if key not in secret_variables else encrypter.decrypt_token(tenant_id, value)
for key, value in datasource_provider.encrypted_credentials.items()
}
new_credentials = {
key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
try:
current_user = get_current_user()
self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id,
user_id=current_user.id,
provider=provider,
plugin_id=plugin_id,
credentials=new_credentials,
)
except Exception as e:
raise ValueError(f"Failed to validate credentials: {str(e)}")
encrypted_credentials = {}
for key, value in new_credentials.items():
if key in secret_variables:
encrypted_credentials[key] = encrypter.encrypt_token(tenant_id, value)
else:
encrypted_credentials[key] = value
datasource_provider.encrypted_credentials = encrypted_credentials
session.commit()
def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
"""
remove datasource credentials.
:param tenant_id: workspace id
:param auth_id: credential id
:param provider: provider name
:param plugin_id: plugin id
:return:
"""
datasource_provider = (
db.session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
.first()
)
if datasource_provider:
db.session.delete(datasource_provider)
db.session.commit()
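# Illustrative usage sketch (not part of the original file): store an API-key credential for a
# datasource plugin, then list credentials with secrets obfuscated. The provider id string and
# the "api_key" field below are assumptions for illustration only.
def _example_add_and_list(tenant_id: str) -> list[dict]:
    service = DatasourceProviderService()
    provider_id = DatasourceProviderID("langgenius/firecrawl_datasource/firecrawl")
    service.add_datasource_api_key_provider(
        name=None,  # a default name is generated when omitted
        tenant_id=tenant_id,
        provider_id=provider_id,
        credentials={"api_key": "<secret>"},
    )
    return service.list_datasource_credentials(
        tenant_id=tenant_id,
        provider=provider_id.provider_name,
        plugin_id=provider_id.plugin_id,
    )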

View File

@@ -0,0 +1,83 @@
import logging
from collections.abc import Callable, Sequence
from dataclasses import asdict
from functools import cached_property
from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from services.feature_service import FeatureService
from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task
logger = logging.getLogger(__name__)
class DocumentIndexingTaskProxy:
def __init__(self, tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
self._tenant_id = tenant_id
self._dataset_id = dataset_id
self._document_ids = document_ids
self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
@cached_property
def features(self):
return FeatureService.get_features(self._tenant_id)
def _send_to_direct_queue(self, task_func: Callable[[str, str, Sequence[str]], None]):
logger.info("send dataset %s to direct queue", self._dataset_id)
task_func.delay( # type: ignore
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
)
def _send_to_tenant_queue(self, task_func: Callable[[str, str, Sequence[str]], None]):
logger.info("send dataset %s to tenant queue", self._dataset_id)
if self._tenant_isolated_task_queue.get_task_key():
# Add to waiting queue using List operations (lpush)
self._tenant_isolated_task_queue.push_tasks(
[
asdict(
DocumentTask(
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
)
)
]
)
logger.info("push tasks: %s - %s", self._dataset_id, self._document_ids)
else:
# Set flag and execute task
self._tenant_isolated_task_queue.set_task_waiting_time()
task_func.delay( # type: ignore
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
)
logger.info("init tasks: %s - %s", self._dataset_id, self._document_ids)
def _send_to_default_tenant_queue(self):
self._send_to_tenant_queue(normal_document_indexing_task)
def _send_to_priority_tenant_queue(self):
self._send_to_tenant_queue(priority_document_indexing_task)
def _send_to_priority_direct_queue(self):
self._send_to_direct_queue(priority_document_indexing_task)
def _dispatch(self):
logger.info(
"dispatch args: %s - %s - %s",
self._tenant_id,
self.features.billing.enabled,
self.features.billing.subscription.plan,
)
# dispatch to different indexing queue with tenant isolation when billing enabled
if self.features.billing.enabled:
if self.features.billing.subscription.plan == CloudPlan.SANDBOX:
# dispatch to normal pipeline queue with tenant self sub queue for sandbox plan
self._send_to_default_tenant_queue()
else:
# dispatch to priority pipeline queue with tenant self sub queue for other plans
self._send_to_priority_tenant_queue()
else:
# dispatch to priority queue without tenant isolation for others, e.g.: self-hosted or enterprise
self._send_to_priority_direct_queue()
def delay(self):
self._dispatch()
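# Illustrative usage sketch (not part of the original file): callers treat the proxy like a
# Celery task; the proxy itself decides between the priority/normal and direct/tenant-isolated
# queues based on the tenant's billing plan.
def _example_enqueue(tenant_id: str, dataset_id: str, document_ids: list[str]) -> None:
    DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids).delay()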

View File

@@ -0,0 +1,163 @@
import logging
from collections.abc import Mapping
from sqlalchemy import case
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from models.model import App, DefaultEndUserSessionID, EndUser
logger = logging.getLogger(__name__)
class EndUserService:
"""
Service for managing end users.
"""
@classmethod
def get_or_create_end_user(cls, app_model: App, user_id: str | None = None) -> EndUser:
"""
Get or create an end user for a given app.
"""
return cls.get_or_create_end_user_by_type(InvokeFrom.SERVICE_API, app_model.tenant_id, app_model.id, user_id)
@classmethod
def get_or_create_end_user_by_type(
cls, type: InvokeFrom, tenant_id: str, app_id: str, user_id: str | None = None
) -> EndUser:
"""
Get or create an end user for a given app and type.
"""
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
with Session(db.engine, expire_on_commit=False) as session:
# Query with ORDER BY to prioritize exact type matches while maintaining backward compatibility
# This single query approach is more efficient than separate queries
end_user = (
session.query(EndUser)
.where(
EndUser.tenant_id == tenant_id,
EndUser.app_id == app_id,
EndUser.session_id == user_id,
)
.order_by(
# Prioritize records with matching type (0 = match, 1 = no match)
case((EndUser.type == type, 0), else_=1)
)
.first()
)
if end_user:
# If found a legacy end user with different type, update it for future consistency
if end_user.type != type:
logger.info(
"Upgrading legacy EndUser %s from type=%s to %s for session_id=%s",
end_user.id,
end_user.type,
type,
user_id,
)
end_user.type = type
session.commit()
else:
# Create new end user if none exists
end_user = EndUser(
tenant_id=tenant_id,
app_id=app_id,
type=type,
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID,
session_id=user_id,
external_user_id=user_id,
)
session.add(end_user)
session.commit()
return end_user
@classmethod
def create_end_user_batch(
cls, type: InvokeFrom, tenant_id: str, app_ids: list[str], user_id: str
) -> Mapping[str, EndUser]:
"""Create end users in batch.
Creates end users in batch for the specified tenant and application IDs using a constant number of database queries.
This batch creation is necessary because trigger subscriptions can span multiple applications,
and trigger events may be dispatched to multiple applications simultaneously.
For each app_id in app_ids, check if an `EndUser` with the given
`user_id` (as session_id/external_user_id) already exists for the
tenant/app and type `type`. If it exists, return it; otherwise,
create it. Operates with minimal DB I/O by querying and inserting in
batches.
Returns a mapping of `app_id -> EndUser`.
"""
# Normalize user_id to default if empty
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
# Deduplicate app_ids while preserving input order
seen: set[str] = set()
unique_app_ids: list[str] = []
for app_id in app_ids:
if app_id not in seen:
seen.add(app_id)
unique_app_ids.append(app_id)
# Result is a simple app_id -> EndUser mapping
result: dict[str, EndUser] = {}
if not unique_app_ids:
return result
with Session(db.engine, expire_on_commit=False) as session:
# Fetch existing end users for all target apps in a single query
existing_end_users: list[EndUser] = (
session.query(EndUser)
.where(
EndUser.tenant_id == tenant_id,
EndUser.app_id.in_(unique_app_ids),
EndUser.session_id == user_id,
EndUser.type == type,
)
.all()
)
found_app_ids: set[str] = set()
for eu in existing_end_users:
# If duplicates exist due to weak DB constraints, prefer the first
if eu.app_id not in result:
result[eu.app_id] = eu
found_app_ids.add(eu.app_id)
# Determine which apps still need an EndUser created
missing_app_ids = [app_id for app_id in unique_app_ids if app_id not in found_app_ids]
if missing_app_ids:
new_end_users: list[EndUser] = []
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
for app_id in missing_app_ids:
new_end_users.append(
EndUser(
tenant_id=tenant_id,
app_id=app_id,
type=type,
is_anonymous=is_anonymous,
session_id=user_id,
external_user_id=user_id,
)
)
session.add_all(new_end_users)
session.commit()
for eu in new_end_users:
result[eu.app_id] = eu
return result
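# Illustrative usage sketch (not part of the original file): resolve one EndUser per app when a
# trigger fans out to several apps, creating any missing records with a constant number of
# database queries.
def _example_batch_end_users(tenant_id: str, app_ids: list[str], user_id: str) -> Mapping[str, EndUser]:
    return EndUserService.create_end_user_batch(
        type=InvokeFrom.SERVICE_API,
        tenant_id=tenant_id,
        app_ids=app_ids,
        user_id=user_id,
    )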

View File

View File

@@ -0,0 +1,55 @@
import os
from collections.abc import Mapping
from typing import Any
import httpx
class BaseRequest:
proxies: Mapping[str, str] | None = {
"http": "",
"https": "",
}
base_url = ""
secret_key = ""
secret_key_header = ""
@classmethod
def _build_mounts(cls) -> dict[str, httpx.BaseTransport] | None:
if not cls.proxies:
return None
mounts: dict[str, httpx.BaseTransport] = {}
for scheme, value in cls.proxies.items():
if not value:
continue
key = f"{scheme}://" if not scheme.endswith("://") else scheme
mounts[key] = httpx.HTTPTransport(proxy=value)
return mounts or None
@classmethod
def send_request(
cls,
method: str,
endpoint: str,
json: Any | None = None,
params: Mapping[str, Any] | None = None,
) -> Any:
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
url = f"{cls.base_url}{endpoint}"
mounts = cls._build_mounts()
with httpx.Client(mounts=mounts) as client:
response = client.request(method, url, json=json, params=params, headers=headers)
return response.json()
class EnterpriseRequest(BaseRequest):
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY")
secret_key_header = "Enterprise-Api-Secret-Key"
class EnterprisePluginManagerRequest(BaseRequest):
base_url = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_URL", "ENTERPRISE_PLUGIN_MANAGER_API_URL")
secret_key = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY", "ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY")
secret_key_header = "Plugin-Manager-Inner-Api-Secret-Key"

View File

@@ -0,0 +1,114 @@
from datetime import datetime
from pydantic import BaseModel, Field
from services.enterprise.base import EnterpriseRequest
class WebAppSettings(BaseModel):
access_mode: str = Field(
description="Access mode for the web app. Can be 'public', 'private', 'private_all', 'sso_verified'",
default="private",
alias="accessMode",
)
class EnterpriseService:
@classmethod
def get_info(cls):
return EnterpriseRequest.send_request("GET", "/info")
@classmethod
def get_workspace_info(cls, tenant_id: str):
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
@classmethod
def get_app_sso_settings_last_update_time(cls) -> datetime:
data = EnterpriseRequest.send_request("GET", "/sso/app/last-update-time")
if not data:
raise ValueError("No data found.")
try:
# parse the UTC timestamp from the response
return datetime.fromisoformat(data)
except ValueError as e:
raise ValueError(f"Invalid date format: {data}") from e
@classmethod
def get_workspace_sso_settings_last_update_time(cls) -> datetime:
data = EnterpriseRequest.send_request("GET", "/sso/workspace/last-update-time")
if not data:
raise ValueError("No data found.")
try:
# parse the UTC timestamp from the response
return datetime.fromisoformat(data)
except ValueError as e:
raise ValueError(f"Invalid date format: {data}") from e
class WebAppAuth:
@classmethod
def is_user_allowed_to_access_webapp(cls, user_id: str, app_id: str):
params = {"userId": user_id, "appId": app_id}
data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params)
return data.get("result", False)
@classmethod
def batch_is_user_allowed_to_access_webapps(cls, user_id: str, app_ids: list[str]):
if not app_ids:
return {}
body = {"userId": user_id, "appIds": app_ids}
data = EnterpriseRequest.send_request("POST", "/webapp/permission/batch", json=body)
if not data:
raise ValueError("No data found.")
return data.get("permissions", {})
@classmethod
def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings:
if not app_id:
raise ValueError("app_id must be provided.")
params = {"appId": app_id}
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params)
if not data:
raise ValueError("No data found.")
return WebAppSettings.model_validate(data)
@classmethod
def batch_get_app_access_mode_by_id(cls, app_ids: list[str]) -> dict[str, WebAppSettings]:
if not app_ids:
return {}
body = {"appIds": app_ids}
data: dict[str, str] = EnterpriseRequest.send_request("POST", "/webapp/access-mode/batch/id", json=body)
if not data:
raise ValueError("No data found.")
if not isinstance(data["accessModes"], dict):
raise ValueError("Invalid data format.")
ret = {}
for key, value in data["accessModes"].items():
curr = WebAppSettings()
curr.access_mode = value
ret[key] = curr
return ret
@classmethod
def update_app_access_mode(cls, app_id: str, access_mode: str):
if not app_id:
raise ValueError("app_id must be provided.")
if access_mode not in ["public", "private", "private_all"]:
raise ValueError("access_mode must be either 'public', 'private', or 'private_all'")
data = {"appId": app_id, "accessMode": access_mode}
response = EnterpriseRequest.send_request("POST", "/webapp/access-mode", json=data)
return response.get("result", False)
@classmethod
def cleanup_webapp(cls, app_id: str):
if not app_id:
raise ValueError("app_id must be provided.")
body = {"appId": app_id}
EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body)

View File

@@ -0,0 +1,57 @@
import enum
import logging
from pydantic import BaseModel
from services.enterprise.base import EnterprisePluginManagerRequest
from services.errors.base import BaseServiceError
logger = logging.getLogger(__name__)
class PluginCredentialType(enum.Enum):
MODEL = 0 # must be 0 for API contract compatibility
TOOL = 1 # must be 1 for API contract compatibility
def to_number(self):
return self.value
class CheckCredentialPolicyComplianceRequest(BaseModel):
dify_credential_id: str
provider: str
credential_type: PluginCredentialType
def model_dump(self, **kwargs):
data = super().model_dump(**kwargs)
data["credential_type"] = self.credential_type.to_number()
return data
class CredentialPolicyViolationError(BaseServiceError):
pass
class PluginManagerService:
@classmethod
def check_credential_policy_compliance(cls, body: CheckCredentialPolicyComplianceRequest):
try:
ret = EnterprisePluginManagerRequest.send_request(
"POST", "/check-credential-policy-compliance", json=body.model_dump()
)
if not isinstance(ret, dict) or "result" not in ret:
raise ValueError("Invalid response format from plugin manager API")
except Exception as e:
raise CredentialPolicyViolationError(
f"error occurred while checking credential policy compliance: {e}"
) from e
if not ret.get("result", False):
raise CredentialPolicyViolationError("Credentials not available: Please use ENTERPRISE global credentials")
        logger.debug(
"Credential policy compliance checked for %s with credential %s, result: %s",
body.provider,
body.dify_credential_id,
ret.get("result", False),
)
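# Illustrative sketch (the id and provider are placeholders): the model_dump override
# serializes the credential type enum as the integer the plugin-manager API expects.
if __name__ == "__main__":
    req = CheckCredentialPolicyComplianceRequest(
        dify_credential_id="credential-uuid",
        provider="openai",
        credential_type=PluginCredentialType.MODEL,
    )
    assert req.model_dump()["credential_type"] == 0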

View File

View File

@@ -0,0 +1,26 @@
from typing import Literal, Union
from pydantic import BaseModel
class AuthorizationConfig(BaseModel):
type: Literal[None, "basic", "bearer", "custom"]
api_key: Union[None, str] = None
header: Union[None, str] = None
class Authorization(BaseModel):
type: Literal["no-auth", "api-key"]
config: AuthorizationConfig | None = None
class ProcessStatusSetting(BaseModel):
request_method: str
url: str
class ExternalKnowledgeApiSetting(BaseModel):
url: str
request_method: str
headers: dict | None = None
params: dict | None = None

View File

@@ -0,0 +1,169 @@
from enum import StrEnum
from typing import Literal
from pydantic import BaseModel
from core.rag.retrieval.retrieval_methods import RetrievalMethod
class ParentMode(StrEnum):
FULL_DOC = "full-doc"
PARAGRAPH = "paragraph"
class NotionIcon(BaseModel):
type: str
url: str | None = None
emoji: str | None = None
class NotionPage(BaseModel):
page_id: str
page_name: str
page_icon: NotionIcon | None = None
type: str
class NotionInfo(BaseModel):
credential_id: str
workspace_id: str
pages: list[NotionPage]
class WebsiteInfo(BaseModel):
provider: str
job_id: str
urls: list[str]
only_main_content: bool = True
class FileInfo(BaseModel):
file_ids: list[str]
class InfoList(BaseModel):
data_source_type: Literal["upload_file", "notion_import", "website_crawl"]
notion_info_list: list[NotionInfo] | None = None
file_info_list: FileInfo | None = None
website_info_list: WebsiteInfo | None = None
class DataSource(BaseModel):
info_list: InfoList
class PreProcessingRule(BaseModel):
id: str
enabled: bool
class Segmentation(BaseModel):
separator: str = "\n"
max_tokens: int
chunk_overlap: int = 0
class Rule(BaseModel):
pre_processing_rules: list[PreProcessingRule] | None = None
segmentation: Segmentation | None = None
parent_mode: Literal["full-doc", "paragraph"] | None = None
subchunk_segmentation: Segmentation | None = None
class ProcessRule(BaseModel):
mode: Literal["automatic", "custom", "hierarchical"]
rules: Rule | None = None
class RerankingModel(BaseModel):
reranking_provider_name: str | None = None
reranking_model_name: str | None = None
class WeightVectorSetting(BaseModel):
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
class WeightKeywordSetting(BaseModel):
keyword_weight: float
class WeightModel(BaseModel):
weight_type: Literal["semantic_first", "keyword_first", "customized"] | None = None
vector_setting: WeightVectorSetting | None = None
keyword_setting: WeightKeywordSetting | None = None
class RetrievalModel(BaseModel):
search_method: RetrievalMethod
reranking_enable: bool
reranking_model: RerankingModel | None = None
reranking_mode: str | None = None
top_k: int
score_threshold_enabled: bool
score_threshold: float | None = None
weights: WeightModel | None = None
class MetaDataConfig(BaseModel):
doc_type: str
doc_metadata: dict
class KnowledgeConfig(BaseModel):
original_document_id: str | None = None
duplicate: bool = True
indexing_technique: Literal["high_quality", "economy"]
data_source: DataSource | None = None
process_rule: ProcessRule | None = None
retrieval_model: RetrievalModel | None = None
doc_form: str = "text_model"
doc_language: str = "English"
embedding_model: str | None = None
embedding_model_provider: str | None = None
name: str | None = None
class SegmentUpdateArgs(BaseModel):
content: str | None = None
answer: str | None = None
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
enabled: bool | None = None
class ChildChunkUpdateArgs(BaseModel):
id: str | None = None
content: str
class MetadataArgs(BaseModel):
type: Literal["string", "number", "time"]
name: str
class MetadataUpdateArgs(BaseModel):
name: str
value: str | int | float | None = None
class MetadataDetail(BaseModel):
id: str
name: str
value: str | int | float | None = None
class DocumentMetadataOperation(BaseModel):
document_id: str
metadata_list: list[MetadataDetail]
partial_update: bool = False
class MetadataOperationData(BaseModel):
"""
Metadata operation data
"""
operation_data: list[DocumentMetadataOperation]
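# Illustrative sketch (separators, token sizes and the pre-processing rule id are
# placeholders): a "hierarchical" process rule pairs a parent segmentation with a
# sub-chunk segmentation and a parent mode.
if __name__ == "__main__":
    rule = ProcessRule(
        mode="hierarchical",
        rules=Rule(
            pre_processing_rules=[PreProcessingRule(id="remove_extra_spaces", enabled=True)],
            segmentation=Segmentation(separator="\n\n", max_tokens=1024),
            parent_mode="paragraph",
            subchunk_segmentation=Segmentation(separator="\n", max_tokens=256),
        ),
    )
    print(rule.model_dump())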

View File

@@ -0,0 +1,132 @@
from typing import Literal
from pydantic import BaseModel, field_validator
from core.rag.retrieval.retrieval_methods import RetrievalMethod
class IconInfo(BaseModel):
icon: str
icon_background: str | None = None
icon_type: str | None = None
icon_url: str | None = None
class PipelineTemplateInfoEntity(BaseModel):
name: str
description: str
icon_info: IconInfo
class RagPipelineDatasetCreateEntity(BaseModel):
name: str
description: str
icon_info: IconInfo
permission: str
partial_member_list: list[str] | None = None
yaml_content: str | None = None
class RerankingModelConfig(BaseModel):
"""
Reranking Model Config.
"""
reranking_provider_name: str | None = ""
reranking_model_name: str | None = ""
class VectorSetting(BaseModel):
"""
Vector Setting.
"""
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
class KeywordSetting(BaseModel):
"""
Keyword Setting.
"""
keyword_weight: float
class WeightedScoreConfig(BaseModel):
"""
Weighted score Config.
"""
vector_setting: VectorSetting | None
keyword_setting: KeywordSetting | None
class EmbeddingSetting(BaseModel):
"""
Embedding Setting.
"""
embedding_provider_name: str
embedding_model_name: str
class EconomySetting(BaseModel):
"""
Economy Setting.
"""
keyword_number: int
class RetrievalSetting(BaseModel):
"""
Retrieval Setting.
"""
search_method: RetrievalMethod
top_k: int
score_threshold: float | None = 0.5
score_threshold_enabled: bool = False
reranking_mode: str | None = "reranking_model"
reranking_enable: bool | None = True
reranking_model: RerankingModelConfig | None = None
weights: WeightedScoreConfig | None = None
class IndexMethod(BaseModel):
"""
Knowledge Index Setting.
"""
indexing_technique: Literal["high_quality", "economy"]
embedding_setting: EmbeddingSetting
economy_setting: EconomySetting
class KnowledgeConfiguration(BaseModel):
"""
Knowledge Base Configuration.
"""
chunk_structure: str
indexing_technique: Literal["high_quality", "economy"]
embedding_model_provider: str = ""
embedding_model: str = ""
keyword_number: int | None = 10
retrieval_model: RetrievalSetting
@field_validator("embedding_model_provider", mode="before")
@classmethod
def validate_embedding_model_provider(cls, v):
if v is None:
return ""
return v
@field_validator("embedding_model", mode="before")
@classmethod
def validate_embedding_model(cls, v):
if v is None:
return ""
return v
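# Illustrative sketch (chunk structure and top_k are placeholders): the "before"
# validators normalize a null embedding provider/model to an empty string, so an
# economy-mode configuration can omit them.
if __name__ == "__main__":
    config = KnowledgeConfiguration(
        chunk_structure="text_model",
        indexing_technique="economy",
        embedding_model_provider=None,
        embedding_model=None,
        retrieval_model=RetrievalSetting(search_method=RetrievalMethod.SEMANTIC_SEARCH, top_k=4),
    )
    assert config.embedding_model_provider == ""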

View File

@@ -0,0 +1,180 @@
from collections.abc import Sequence
from enum import StrEnum
from pydantic import BaseModel, ConfigDict, model_validator
from configs import dify_config
from core.entities.model_entities import (
ModelWithProviderEntity,
ProviderModelWithStatusEntity,
)
from core.entities.provider_entities import (
CredentialConfiguration,
CustomModelConfiguration,
ProviderQuotaType,
QuotaConfiguration,
UnaddedModelConfiguration,
)
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import (
ConfigurateMethod,
ModelCredentialSchema,
ProviderCredentialSchema,
ProviderHelpEntity,
SimpleProviderEntity,
)
from models.provider import ProviderType
class CustomConfigurationStatus(StrEnum):
"""
Enum class for custom configuration status.
"""
ACTIVE = "active"
NO_CONFIGURE = "no-configure"
class CustomConfigurationResponse(BaseModel):
"""
Model class for provider custom configuration response.
"""
status: CustomConfigurationStatus
current_credential_id: str | None = None
current_credential_name: str | None = None
available_credentials: list[CredentialConfiguration] | None = None
custom_models: list[CustomModelConfiguration] | None = None
can_added_models: list[UnaddedModelConfiguration] | None = None
class SystemConfigurationResponse(BaseModel):
"""
Model class for provider system configuration response.
"""
enabled: bool
current_quota_type: ProviderQuotaType | None = None
quota_configurations: list[QuotaConfiguration] = []
class ProviderResponse(BaseModel):
"""
Model class for provider response.
"""
tenant_id: str
provider: str
label: I18nObject
description: I18nObject | None = None
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
background: str | None = None
help: ProviderHelpEntity | None = None
supported_model_types: Sequence[ModelType]
configurate_methods: list[ConfigurateMethod]
provider_credential_schema: ProviderCredentialSchema | None = None
model_credential_schema: ModelCredentialSchema | None = None
preferred_provider_type: ProviderType
custom_configuration: CustomConfigurationResponse
system_configuration: SystemConfigurationResponse
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@model_validator(mode="after")
def _(self):
url_prefix = (
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
)
if self.icon_small is not None:
self.icon_small = I18nObject(
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
)
if self.icon_large is not None:
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
)
return self
class ProviderWithModelsResponse(BaseModel):
"""
Model class for provider with models response.
"""
tenant_id: str
provider: str
label: I18nObject
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
status: CustomConfigurationStatus
models: list[ProviderModelWithStatusEntity]
@model_validator(mode="after")
def _(self):
url_prefix = (
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
)
if self.icon_small is not None:
self.icon_small = I18nObject(
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
)
if self.icon_large is not None:
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
)
return self
class SimpleProviderEntityResponse(SimpleProviderEntity):
"""
Simple provider entity response.
"""
tenant_id: str
@model_validator(mode="after")
def _(self):
url_prefix = (
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
)
if self.icon_small is not None:
self.icon_small = I18nObject(
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
)
if self.icon_large is not None:
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
)
return self
class DefaultModelResponse(BaseModel):
"""
Default model entity.
"""
model: str
model_type: ModelType
provider: SimpleProviderEntityResponse
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
class ModelWithProviderEntityResponse(ProviderModelWithStatusEntity):
"""
Model with provider entity.
"""
provider: SimpleProviderEntityResponse
def __init__(self, tenant_id: str, model: ModelWithProviderEntity):
dump_model = model.model_dump()
dump_model["provider"]["tenant_id"] = tenant_id
super().__init__(**dump_model)

View File

@@ -0,0 +1,27 @@
from . import (
account,
app,
app_model_config,
audio,
base,
conversation,
dataset,
document,
file,
index,
message,
)
__all__ = [
"account",
"app",
"app_model_config",
"audio",
"base",
"conversation",
"dataset",
"document",
"file",
"index",
"message",
]

View File

@@ -0,0 +1,57 @@
from services.errors.base import BaseServiceError
class AccountNotFoundError(BaseServiceError):
pass
class AccountRegisterError(BaseServiceError):
pass
class AccountLoginError(BaseServiceError):
pass
class AccountPasswordError(BaseServiceError):
pass
class AccountNotLinkTenantError(BaseServiceError):
pass
class CurrentPasswordIncorrectError(BaseServiceError):
pass
class LinkAccountIntegrateError(BaseServiceError):
pass
class TenantNotFoundError(BaseServiceError):
pass
class AccountAlreadyInTenantError(BaseServiceError):
pass
class InvalidActionError(BaseServiceError):
pass
class CannotOperateSelfError(BaseServiceError):
pass
class NoPermissionError(BaseServiceError):
pass
class MemberNotInTenantError(BaseServiceError):
pass
class RoleAlreadyAssignedError(BaseServiceError):
pass

View File

@@ -0,0 +1,46 @@
class MoreLikeThisDisabledError(Exception):
pass
class WorkflowHashNotEqualError(Exception):
pass
class IsDraftWorkflowError(Exception):
pass
class WorkflowNotFoundError(Exception):
pass
class WorkflowIdFormatError(Exception):
pass
class InvokeRateLimitError(Exception):
"""Raised when rate limit is exceeded for workflow invocations."""
pass
class QuotaExceededError(ValueError):
"""Raised when billing quota is exceeded for a feature."""
def __init__(self, feature: str, tenant_id: str, required: int):
self.feature = feature
self.tenant_id = tenant_id
self.required = required
super().__init__(f"Quota exceeded for feature '{feature}' (tenant: {tenant_id}). Required: {required}")
class TriggerNodeLimitExceededError(ValueError):
"""Raised when trigger node count exceeds the plan limit."""
def __init__(self, count: int, limit: int):
self.count = count
self.limit = limit
super().__init__(
f"Trigger node count ({count}) exceeds the limit ({limit}) for your subscription plan. "
f"Please upgrade your plan or reduce the number of trigger nodes."
)

View File

@@ -0,0 +1,9 @@
from services.errors.base import BaseServiceError
class AppModelConfigBrokenError(BaseServiceError):
pass
class ProviderNotFoundError(BaseServiceError):
pass

View File

@@ -0,0 +1,22 @@
class NoAudioUploadedServiceError(Exception):
pass
class AudioTooLargeServiceError(Exception):
pass
class UnsupportedAudioTypeServiceError(Exception):
pass
class ProviderNotSupportSpeechToTextServiceError(Exception):
pass
class ProviderNotSupportTextToSpeechServiceError(Exception):
pass
class ProviderNotSupportTextToSpeechLanageServiceError(Exception):
pass

View File

@@ -0,0 +1,3 @@
class BaseServiceError(ValueError):
def __init__(self, description: str | None = None):
self.description = description

View File

@@ -0,0 +1,9 @@
from services.errors.base import BaseServiceError
class ChildChunkIndexingError(BaseServiceError):
description = "{message}"
class ChildChunkDeleteIndexError(BaseServiceError):
description = "{message}"

View File

@@ -0,0 +1,21 @@
from services.errors.base import BaseServiceError
class LastConversationNotExistsError(BaseServiceError):
pass
class ConversationNotExistsError(BaseServiceError):
pass
class ConversationCompletedError(Exception):
pass
class ConversationVariableNotExistsError(BaseServiceError):
pass
class ConversationVariableTypeMismatchError(BaseServiceError):
pass

View File

@@ -0,0 +1,9 @@
from services.errors.base import BaseServiceError
class DatasetNameDuplicateError(BaseServiceError):
pass
class DatasetInUseError(BaseServiceError):
pass

View File

@@ -0,0 +1,5 @@
from services.errors.base import BaseServiceError
class DocumentIndexingError(BaseServiceError):
pass

View File

@@ -0,0 +1,17 @@
from services.errors.base import BaseServiceError
class FileNotExistsError(BaseServiceError):
pass
class FileTooLargeError(BaseServiceError):
description = "{message}"
class UnsupportedFileTypeError(BaseServiceError):
pass
class BlockedFileExtensionError(BaseServiceError):
description = "File extension '{extension}' is not allowed for security reasons"

View File

@@ -0,0 +1,5 @@
from services.errors.base import BaseServiceError
class IndexNotInitializedError(BaseServiceError):
pass

View File

@@ -0,0 +1,16 @@
class InvokeError(Exception):
"""Base class for all LLM exceptions."""
description: str | None = None
def __init__(self, description: str | None = None):
self.description = description
def __str__(self):
return self.description or self.__class__.__name__
class InvokeRateLimitError(InvokeError):
"""Raised when the Invoke returns rate limit error."""
description = "Rate Limit Error"

View File

@@ -0,0 +1,17 @@
from services.errors.base import BaseServiceError
class FirstMessageNotExistsError(BaseServiceError):
pass
class LastMessageNotExistsError(BaseServiceError):
pass
class MessageNotExistsError(BaseServiceError):
pass
class SuggestedQuestionsAfterAnswerDisabledError(BaseServiceError):
pass

View File

@@ -0,0 +1,5 @@
from services.errors.base import BaseServiceError
class PluginInstallationForbiddenError(BaseServiceError):
pass

View File

@@ -0,0 +1,10 @@
class WorkflowInUseError(ValueError):
"""Raised when attempting to delete a workflow that's in use by an app"""
pass
class DraftWorkflowDeletionError(ValueError):
"""Raised when attempting to delete a draft workflow"""
pass

View File

@@ -0,0 +1,13 @@
from services.errors.base import BaseServiceError
class WorkSpaceNotAllowedCreateError(BaseServiceError):
pass
class WorkSpaceNotFoundError(BaseServiceError):
pass
class WorkspacesLimitExceededError(BaseServiceError):
pass

View File

@@ -0,0 +1,327 @@
import json
from copy import deepcopy
from typing import Any, Union, cast
from urllib.parse import urlparse
import httpx
from sqlalchemy import select
from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy
from core.rag.entities.metadata_entities import MetadataCondition
from core.workflow.nodes.http_request.exc import InvalidHttpMethodError
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import (
Dataset,
ExternalKnowledgeApis,
ExternalKnowledgeBindings,
)
from services.entities.external_knowledge_entities.external_knowledge_entities import (
Authorization,
ExternalKnowledgeApiSetting,
)
from services.errors.dataset import DatasetNameDuplicateError
class ExternalDatasetService:
@staticmethod
def get_external_knowledge_apis(
page, per_page, tenant_id, search=None
) -> tuple[list[ExternalKnowledgeApis], int | None]:
query = (
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.tenant_id == tenant_id)
.order_by(ExternalKnowledgeApis.created_at.desc())
)
if search:
query = query.where(ExternalKnowledgeApis.name.ilike(f"%{search}%"))
external_knowledge_apis = db.paginate(
select=query, page=page, per_page=per_page, max_per_page=100, error_out=False
)
return external_knowledge_apis.items, external_knowledge_apis.total
@classmethod
def validate_api_list(cls, api_settings: dict):
if not api_settings:
raise ValueError("api list is empty")
if not api_settings.get("endpoint"):
raise ValueError("endpoint is required")
if not api_settings.get("api_key"):
raise ValueError("api_key is required")
@staticmethod
def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict) -> ExternalKnowledgeApis:
settings = args.get("settings")
if settings is None:
raise ValueError("settings is required")
ExternalDatasetService.check_endpoint_and_api_key(settings)
external_knowledge_api = ExternalKnowledgeApis(
tenant_id=tenant_id,
created_by=user_id,
updated_by=user_id,
name=str(args.get("name")),
description=args.get("description", ""),
settings=json.dumps(args.get("settings"), ensure_ascii=False),
)
db.session.add(external_knowledge_api)
db.session.commit()
return external_knowledge_api
@staticmethod
def check_endpoint_and_api_key(settings: dict):
if "endpoint" not in settings or not settings["endpoint"]:
raise ValueError("endpoint is required")
if "api_key" not in settings or not settings["api_key"]:
raise ValueError("api_key is required")
endpoint = f"{settings['endpoint']}/retrieval"
api_key = settings["api_key"]
parsed_url = urlparse(endpoint)
if not all([parsed_url.scheme, parsed_url.netloc]):
if not endpoint.startswith("http://") and not endpoint.startswith("https://"):
raise ValueError(f"invalid endpoint: {endpoint} must start with http:// or https://")
else:
raise ValueError(f"invalid endpoint: {endpoint}")
try:
response = ssrf_proxy.post(endpoint, headers={"Authorization": f"Bearer {api_key}"})
except Exception as e:
raise ValueError(f"failed to connect to the endpoint: {endpoint}") from e
if response.status_code == 502:
raise ValueError(f"Bad Gateway: failed to connect to the endpoint: {endpoint}")
if response.status_code == 404:
raise ValueError(f"Not Found: failed to connect to the endpoint: {endpoint}")
if response.status_code == 403:
raise ValueError(f"Forbidden: Authorization failed with api_key: {api_key}")
@staticmethod
def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis:
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id).first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
return external_knowledge_api
@staticmethod
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
settings = args.get("settings")
if settings and settings.get("api_key") == HIDDEN_VALUE and external_knowledge_api.settings_dict:
settings["api_key"] = external_knowledge_api.settings_dict.get("api_key")
external_knowledge_api.name = args.get("name")
external_knowledge_api.description = args.get("description", "")
external_knowledge_api.settings = json.dumps(args.get("settings"), ensure_ascii=False)
external_knowledge_api.updated_by = user_id
external_knowledge_api.updated_at = naive_utc_now()
db.session.commit()
return external_knowledge_api
@staticmethod
def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str):
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
db.session.delete(external_knowledge_api)
db.session.commit()
@staticmethod
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
count = (
db.session.query(ExternalKnowledgeBindings)
.filter_by(external_knowledge_api_id=external_knowledge_api_id)
.count()
)
if count > 0:
return True, count
return False, 0
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
external_knowledge_binding: ExternalKnowledgeBindings | None = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
)
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
return external_knowledge_binding
@staticmethod
def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
)
if external_knowledge_api is None or external_knowledge_api.settings is None:
raise ValueError("api template not found")
settings = json.loads(external_knowledge_api.settings)
for setting in settings:
custom_parameters = setting.get("document_process_setting")
if custom_parameters:
for parameter in custom_parameters:
if parameter.get("required", False) and not process_parameter.get(parameter.get("name")):
raise ValueError(f"{parameter.get('name')} is required")
@staticmethod
def process_external_api(
settings: ExternalKnowledgeApiSetting, files: Union[None, dict[str, Any]]
) -> httpx.Response:
"""
        Perform the HTTP request described by the external knowledge API settings.
"""
kwargs: dict[str, Any] = {
"url": settings.url,
"headers": settings.headers,
"follow_redirects": True,
}
_METHOD_MAP = {
"get": ssrf_proxy.get,
"head": ssrf_proxy.head,
"post": ssrf_proxy.post,
"put": ssrf_proxy.put,
"delete": ssrf_proxy.delete,
"patch": ssrf_proxy.patch,
}
method_lc = settings.request_method.lower()
if method_lc not in _METHOD_MAP:
raise InvalidHttpMethodError(f"Invalid http method {settings.request_method}")
response: httpx.Response = _METHOD_MAP[method_lc](data=json.dumps(settings.params), files=files, **kwargs)
return response
@staticmethod
def assembling_headers(authorization: Authorization, headers: dict | None = None) -> dict[str, Any]:
authorization = deepcopy(authorization)
if headers:
headers = deepcopy(headers)
else:
headers = {}
if authorization.type == "api-key":
if authorization.config is None:
raise ValueError("authorization config is required")
if authorization.config.api_key is None:
raise ValueError("api_key is required")
if not authorization.config.header:
authorization.config.header = "Authorization"
if authorization.config.type == "bearer":
headers[authorization.config.header] = f"Bearer {authorization.config.api_key}"
elif authorization.config.type == "basic":
headers[authorization.config.header] = f"Basic {authorization.config.api_key}"
elif authorization.config.type == "custom":
headers[authorization.config.header] = authorization.config.api_key
return headers
@staticmethod
def get_external_knowledge_api_settings(settings: dict) -> ExternalKnowledgeApiSetting:
return ExternalKnowledgeApiSetting.model_validate(settings)
@staticmethod
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
# check if dataset name already exists
if db.session.query(Dataset).filter_by(name=args.get("name"), tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id)
.first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
dataset = Dataset(
tenant_id=tenant_id,
name=args.get("name"),
description=args.get("description", ""),
provider="external",
retrieval_model=args.get("external_retrieval_model"),
created_by=user_id,
)
db.session.add(dataset)
db.session.flush()
if args.get("external_knowledge_id") is None:
raise ValueError("external_knowledge_id is required")
if args.get("external_knowledge_api_id") is None:
raise ValueError("external_knowledge_api_id is required")
external_knowledge_binding = ExternalKnowledgeBindings(
tenant_id=tenant_id,
dataset_id=dataset.id,
external_knowledge_api_id=args.get("external_knowledge_api_id") or "",
external_knowledge_id=args.get("external_knowledge_id") or "",
created_by=user_id,
)
db.session.add(external_knowledge_binding)
db.session.commit()
return dataset
@staticmethod
def fetch_external_knowledge_retrieval(
tenant_id: str,
dataset_id: str,
query: str,
external_retrieval_parameters: dict,
metadata_condition: MetadataCondition | None = None,
):
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
)
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter_by(id=external_knowledge_binding.external_knowledge_api_id)
.first()
)
if external_knowledge_api is None or external_knowledge_api.settings is None:
raise ValueError("external api template not found")
settings = json.loads(external_knowledge_api.settings)
headers = {"Content-Type": "application/json"}
if settings.get("api_key"):
headers["Authorization"] = f"Bearer {settings.get('api_key')}"
score_threshold_enabled = external_retrieval_parameters.get("score_threshold_enabled") or False
score_threshold = external_retrieval_parameters.get("score_threshold", 0.0) if score_threshold_enabled else 0.0
request_params = {
"retrieval_setting": {
"top_k": external_retrieval_parameters.get("top_k"),
"score_threshold": score_threshold,
},
"query": query,
"knowledge_id": external_knowledge_binding.external_knowledge_id,
"metadata_condition": metadata_condition.model_dump() if metadata_condition else None,
}
response = ExternalDatasetService.process_external_api(
ExternalKnowledgeApiSetting(
url=f"{settings.get('endpoint')}/retrieval",
request_method="post",
headers=headers,
params=request_params,
),
None,
)
if response.status_code == 200:
return cast(list[Any], response.json().get("records", []))
return []
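# Illustrative sketch (the api key is a placeholder): assembling_headers turns an
# api-key Authorization into the Bearer header used for external retrieval calls.
if __name__ == "__main__":
    from services.entities.external_knowledge_entities.external_knowledge_entities import AuthorizationConfig

    auth = Authorization(type="api-key", config=AuthorizationConfig(type="bearer", api_key="demo-key"))
    print(ExternalDatasetService.assembling_headers(auth))
    # -> {"Authorization": "Bearer demo-key"}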

View File

@@ -0,0 +1,363 @@
from enum import StrEnum
from pydantic import BaseModel, ConfigDict, Field
from configs import dify_config
from enums.cloud_plan import CloudPlan
from services.billing_service import BillingService
from services.enterprise.enterprise_service import EnterpriseService
class SubscriptionModel(BaseModel):
plan: str = CloudPlan.SANDBOX
interval: str = ""
class BillingModel(BaseModel):
enabled: bool = False
subscription: SubscriptionModel = SubscriptionModel()
class EducationModel(BaseModel):
enabled: bool = False
activated: bool = False
class LimitationModel(BaseModel):
size: int = 0
limit: int = 0
class LicenseLimitationModel(BaseModel):
"""
- enabled: whether this limit is enforced
- size: current usage count
- limit: maximum allowed count; 0 means unlimited
"""
enabled: bool = Field(False, description="Whether this limit is currently active")
size: int = Field(0, description="Number of resources already consumed")
limit: int = Field(0, description="Maximum number of resources allowed; 0 means no limit")
def is_available(self, required: int = 1) -> bool:
"""
Determine whether the requested amount can be allocated.
Returns True if:
- this limit is not active, or
- the limit is zero (unlimited), or
- there is enough remaining quota.
"""
if not self.enabled or self.limit == 0:
return True
return (self.limit - self.size) >= required
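    # Worked example: with enabled=True, size=8, limit=10, is_available(2) is True
    # and is_available(3) is False; with limit=0 the check always passes because a
    # zero limit means unlimited.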
class Quota(BaseModel):
usage: int = 0
limit: int = 0
reset_date: int = -1
class LicenseStatus(StrEnum):
NONE = "none"
INACTIVE = "inactive"
ACTIVE = "active"
EXPIRING = "expiring"
EXPIRED = "expired"
LOST = "lost"
class LicenseModel(BaseModel):
status: LicenseStatus = LicenseStatus.NONE
expired_at: str = ""
workspaces: LicenseLimitationModel = LicenseLimitationModel(enabled=False, size=0, limit=0)
class BrandingModel(BaseModel):
enabled: bool = False
application_title: str = ""
login_page_logo: str = ""
workspace_logo: str = ""
favicon: str = ""
class WebAppAuthSSOModel(BaseModel):
protocol: str = ""
class WebAppAuthModel(BaseModel):
enabled: bool = False
allow_sso: bool = False
sso_config: WebAppAuthSSOModel = WebAppAuthSSOModel()
allow_email_code_login: bool = False
allow_email_password_login: bool = False
class KnowledgePipeline(BaseModel):
publish_enabled: bool = False
class PluginInstallationScope(StrEnum):
NONE = "none"
OFFICIAL_ONLY = "official_only"
OFFICIAL_AND_SPECIFIC_PARTNERS = "official_and_specific_partners"
ALL = "all"
class PluginInstallationPermissionModel(BaseModel):
# Plugin installation scope possible values:
# none: prohibit all plugin installations
# official_only: allow only Dify official plugins
# official_and_specific_partners: allow official and specific partner plugins
# all: allow installation of all plugins
plugin_installation_scope: PluginInstallationScope = PluginInstallationScope.ALL
# If True, restrict plugin installation to the marketplace only
# Equivalent to ForceEnablePluginVerification
restrict_to_marketplace_only: bool = False
class FeatureModel(BaseModel):
billing: BillingModel = BillingModel()
education: EducationModel = EducationModel()
members: LimitationModel = LimitationModel(size=0, limit=1)
apps: LimitationModel = LimitationModel(size=0, limit=10)
vector_space: LimitationModel = LimitationModel(size=0, limit=5)
knowledge_rate_limit: int = 10
annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10)
documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50)
docs_processing: str = "standard"
can_replace_logo: bool = False
model_load_balancing_enabled: bool = False
dataset_operator_enabled: bool = False
webapp_copyright_enabled: bool = False
workspace_members: LicenseLimitationModel = LicenseLimitationModel(enabled=False, size=0, limit=0)
is_allow_transfer_workspace: bool = True
trigger_event: Quota = Quota(usage=0, limit=3000, reset_date=0)
api_rate_limit: Quota = Quota(usage=0, limit=5000, reset_date=0)
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
knowledge_pipeline: KnowledgePipeline = KnowledgePipeline()
class KnowledgeRateLimitModel(BaseModel):
enabled: bool = False
limit: int = 10
subscription_plan: str = ""
class PluginManagerModel(BaseModel):
enabled: bool = False
class SystemFeatureModel(BaseModel):
sso_enforced_for_signin: bool = False
sso_enforced_for_signin_protocol: str = ""
enable_marketplace: bool = False
max_plugin_package_size: int = dify_config.PLUGIN_MAX_PACKAGE_SIZE
enable_email_code_login: bool = False
enable_email_password_login: bool = True
enable_social_oauth_login: bool = False
is_allow_register: bool = False
is_allow_create_workspace: bool = False
is_email_setup: bool = False
license: LicenseModel = LicenseModel()
branding: BrandingModel = BrandingModel()
webapp_auth: WebAppAuthModel = WebAppAuthModel()
plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel()
enable_change_email: bool = True
plugin_manager: PluginManagerModel = PluginManagerModel()
class FeatureService:
@classmethod
def get_features(cls, tenant_id: str) -> FeatureModel:
features = FeatureModel()
cls._fulfill_params_from_env(features)
if dify_config.BILLING_ENABLED and tenant_id:
cls._fulfill_params_from_billing_api(features, tenant_id)
if dify_config.ENTERPRISE_ENABLED:
features.webapp_copyright_enabled = True
features.knowledge_pipeline.publish_enabled = True
cls._fulfill_params_from_workspace_info(features, tenant_id)
return features
@classmethod
def get_knowledge_rate_limit(cls, tenant_id: str):
knowledge_rate_limit = KnowledgeRateLimitModel()
if dify_config.BILLING_ENABLED and tenant_id:
knowledge_rate_limit.enabled = True
limit_info = BillingService.get_knowledge_rate_limit(tenant_id)
knowledge_rate_limit.limit = limit_info.get("limit", 10)
knowledge_rate_limit.subscription_plan = limit_info.get("subscription_plan", CloudPlan.SANDBOX)
return knowledge_rate_limit
@classmethod
def get_system_features(cls) -> SystemFeatureModel:
system_features = SystemFeatureModel()
cls._fulfill_system_params_from_env(system_features)
if dify_config.ENTERPRISE_ENABLED:
system_features.branding.enabled = True
system_features.webapp_auth.enabled = True
system_features.enable_change_email = False
system_features.plugin_manager.enabled = True
cls._fulfill_params_from_enterprise(system_features)
if dify_config.MARKETPLACE_ENABLED:
system_features.enable_marketplace = True
return system_features
@classmethod
def _fulfill_system_params_from_env(cls, system_features: SystemFeatureModel):
system_features.enable_email_code_login = dify_config.ENABLE_EMAIL_CODE_LOGIN
system_features.enable_email_password_login = dify_config.ENABLE_EMAIL_PASSWORD_LOGIN
system_features.enable_social_oauth_login = dify_config.ENABLE_SOCIAL_OAUTH_LOGIN
system_features.is_allow_register = dify_config.ALLOW_REGISTER
system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
@classmethod
def _fulfill_params_from_env(cls, features: FeatureModel):
features.can_replace_logo = dify_config.CAN_REPLACE_LOGO
features.model_load_balancing_enabled = dify_config.MODEL_LB_ENABLED
features.dataset_operator_enabled = dify_config.DATASET_OPERATOR_ENABLED
features.education.enabled = dify_config.EDUCATION_ENABLED
@classmethod
def _fulfill_params_from_workspace_info(cls, features: FeatureModel, tenant_id: str):
workspace_info = EnterpriseService.get_workspace_info(tenant_id)
if "WorkspaceMembers" in workspace_info:
features.workspace_members.size = workspace_info["WorkspaceMembers"]["used"]
features.workspace_members.limit = workspace_info["WorkspaceMembers"]["limit"]
features.workspace_members.enabled = workspace_info["WorkspaceMembers"]["enabled"]
@classmethod
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
billing_info = BillingService.get_info(tenant_id)
features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id)
features.billing.enabled = billing_info["enabled"]
features.billing.subscription.plan = billing_info["subscription"]["plan"]
features.billing.subscription.interval = billing_info["subscription"]["interval"]
features.education.activated = billing_info["subscription"].get("education", False)
if features.billing.subscription.plan != CloudPlan.SANDBOX:
features.webapp_copyright_enabled = True
else:
features.is_allow_transfer_workspace = False
if "trigger_event" in features_usage_info:
features.trigger_event.usage = features_usage_info["trigger_event"]["usage"]
features.trigger_event.limit = features_usage_info["trigger_event"]["limit"]
features.trigger_event.reset_date = features_usage_info["trigger_event"].get("reset_date", -1)
if "api_rate_limit" in features_usage_info:
features.api_rate_limit.usage = features_usage_info["api_rate_limit"]["usage"]
features.api_rate_limit.limit = features_usage_info["api_rate_limit"]["limit"]
features.api_rate_limit.reset_date = features_usage_info["api_rate_limit"].get("reset_date", -1)
if "members" in billing_info:
features.members.size = billing_info["members"]["size"]
features.members.limit = billing_info["members"]["limit"]
if "apps" in billing_info:
features.apps.size = billing_info["apps"]["size"]
features.apps.limit = billing_info["apps"]["limit"]
if "vector_space" in billing_info:
features.vector_space.size = billing_info["vector_space"]["size"]
features.vector_space.limit = billing_info["vector_space"]["limit"]
if "documents_upload_quota" in billing_info:
features.documents_upload_quota.size = billing_info["documents_upload_quota"]["size"]
features.documents_upload_quota.limit = billing_info["documents_upload_quota"]["limit"]
if "annotation_quota_limit" in billing_info:
features.annotation_quota_limit.size = billing_info["annotation_quota_limit"]["size"]
features.annotation_quota_limit.limit = billing_info["annotation_quota_limit"]["limit"]
if "docs_processing" in billing_info:
features.docs_processing = billing_info["docs_processing"]
if "can_replace_logo" in billing_info:
features.can_replace_logo = billing_info["can_replace_logo"]
if "model_load_balancing_enabled" in billing_info:
features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"]
if "knowledge_rate_limit" in billing_info:
features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"]
if "knowledge_pipeline_publish_enabled" in billing_info:
features.knowledge_pipeline.publish_enabled = billing_info["knowledge_pipeline_publish_enabled"]
@classmethod
def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel):
enterprise_info = EnterpriseService.get_info()
if "SSOEnforcedForSignin" in enterprise_info:
features.sso_enforced_for_signin = enterprise_info["SSOEnforcedForSignin"]
if "SSOEnforcedForSigninProtocol" in enterprise_info:
features.sso_enforced_for_signin_protocol = enterprise_info["SSOEnforcedForSigninProtocol"]
if "EnableEmailCodeLogin" in enterprise_info:
features.enable_email_code_login = enterprise_info["EnableEmailCodeLogin"]
if "EnableEmailPasswordLogin" in enterprise_info:
features.enable_email_password_login = enterprise_info["EnableEmailPasswordLogin"]
if "IsAllowRegister" in enterprise_info:
features.is_allow_register = enterprise_info["IsAllowRegister"]
if "IsAllowCreateWorkspace" in enterprise_info:
features.is_allow_create_workspace = enterprise_info["IsAllowCreateWorkspace"]
if "Branding" in enterprise_info:
features.branding.application_title = enterprise_info["Branding"].get("applicationTitle", "")
features.branding.login_page_logo = enterprise_info["Branding"].get("loginPageLogo", "")
features.branding.workspace_logo = enterprise_info["Branding"].get("workspaceLogo", "")
features.branding.favicon = enterprise_info["Branding"].get("favicon", "")
if "WebAppAuth" in enterprise_info:
features.webapp_auth.allow_sso = enterprise_info["WebAppAuth"].get("allowSso", False)
features.webapp_auth.allow_email_code_login = enterprise_info["WebAppAuth"].get(
"allowEmailCodeLogin", False
)
features.webapp_auth.allow_email_password_login = enterprise_info["WebAppAuth"].get(
"allowEmailPasswordLogin", False
)
features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "")
if "License" in enterprise_info:
license_info = enterprise_info["License"]
if "status" in license_info:
features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
if "expiredAt" in license_info:
features.license.expired_at = license_info["expiredAt"]
if "workspaces" in license_info:
features.license.workspaces.enabled = license_info["workspaces"]["enabled"]
features.license.workspaces.limit = license_info["workspaces"]["limit"]
features.license.workspaces.size = license_info["workspaces"]["used"]
if "PluginInstallationPermission" in enterprise_info:
plugin_installation_info = enterprise_info["PluginInstallationPermission"]
features.plugin_installation_permission.plugin_installation_scope = plugin_installation_info[
"pluginInstallationScope"
]
features.plugin_installation_permission.restrict_to_marketplace_only = plugin_installation_info[
"restrictToMarketplaceOnly"
]
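# Illustrative sketch (the tenant id is a placeholder; with billing and enterprise
# disabled this simply returns the env-derived defaults): check the remaining
# workspace-member quota with the LicenseLimitationModel helper before inviting.
if __name__ == "__main__":
    features = FeatureService.get_features("tenant-uuid")
    if not features.workspace_members.is_available(required=1):
        print("workspace member limit reached")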

View File

@@ -0,0 +1,185 @@
import csv
import io
import json
from datetime import datetime
from flask import Response
from sqlalchemy import or_
from extensions.ext_database import db
from models.model import Account, App, Conversation, Message, MessageFeedback
class FeedbackService:
@staticmethod
def export_feedbacks(
app_id: str,
from_source: str | None = None,
rating: str | None = None,
has_comment: bool | None = None,
start_date: str | None = None,
end_date: str | None = None,
format_type: str = "csv",
):
"""
Export feedback data with message details for analysis
Args:
app_id: Application ID
from_source: Filter by feedback source ('user' or 'admin')
rating: Filter by rating ('like' or 'dislike')
has_comment: Only include feedback with comments
start_date: Start date filter (YYYY-MM-DD)
end_date: End date filter (YYYY-MM-DD)
format_type: Export format ('csv' or 'json')
"""
# Validate format early to avoid hitting DB when unnecessary
fmt = (format_type or "csv").lower()
if fmt not in {"csv", "json"}:
raise ValueError(f"Unsupported format: {format_type}")
# Build base query
query = (
db.session.query(MessageFeedback, Message, Conversation, App, Account)
.join(Message, MessageFeedback.message_id == Message.id)
.join(Conversation, MessageFeedback.conversation_id == Conversation.id)
.join(App, MessageFeedback.app_id == App.id)
.outerjoin(Account, MessageFeedback.from_account_id == Account.id)
.where(MessageFeedback.app_id == app_id)
)
# Apply filters
if from_source:
query = query.filter(MessageFeedback.from_source == from_source)
if rating:
query = query.filter(MessageFeedback.rating == rating)
if has_comment is not None:
if has_comment:
query = query.filter(MessageFeedback.content.isnot(None), MessageFeedback.content != "")
else:
query = query.filter(or_(MessageFeedback.content.is_(None), MessageFeedback.content == ""))
if start_date:
try:
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
query = query.filter(MessageFeedback.created_at >= start_dt)
except ValueError:
raise ValueError(f"Invalid start_date format: {start_date}. Use YYYY-MM-DD")
if end_date:
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
query = query.filter(MessageFeedback.created_at <= end_dt)
except ValueError:
raise ValueError(f"Invalid end_date format: {end_date}. Use YYYY-MM-DD")
# Order by creation date (newest first)
query = query.order_by(MessageFeedback.created_at.desc())
# Execute query
results = query.all()
# Prepare data for export
export_data = []
for feedback, message, conversation, app, account in results:
# Get the user query from the message
            user_query = message.query or (message.inputs.get("query", "") if message.inputs else "")
# Format the feedback data
feedback_record = {
"feedback_id": str(feedback.id),
"app_name": app.name,
"app_id": str(app.id),
"conversation_id": str(conversation.id),
"conversation_name": conversation.name or "",
"message_id": str(message.id),
"user_query": user_query,
"ai_response": message.answer[:500] + "..."
if len(message.answer) > 500
else message.answer, # Truncate long responses
"feedback_rating": "👍" if feedback.rating == "like" else "👎",
"feedback_rating_raw": feedback.rating,
"feedback_comment": feedback.content or "",
"feedback_source": feedback.from_source,
"feedback_date": feedback.created_at.strftime("%Y-%m-%d %H:%M:%S"),
"message_date": message.created_at.strftime("%Y-%m-%d %H:%M:%S"),
"from_account_name": account.name if account else "",
"from_end_user_id": str(feedback.from_end_user_id) if feedback.from_end_user_id else "",
"has_comment": "Yes" if feedback.content and feedback.content.strip() else "No",
}
export_data.append(feedback_record)
# Export based on format
if fmt == "csv":
return FeedbackService._export_csv(export_data, app_id)
else: # fmt == "json"
return FeedbackService._export_json(export_data, app_id)
@staticmethod
def _export_csv(data, app_id):
"""Export data as CSV"""
        # An empty result still produces a CSV containing only the header row
# Create CSV in memory
output = io.StringIO()
# Define headers
headers = [
"feedback_id",
"app_name",
"app_id",
"conversation_id",
"conversation_name",
"message_id",
"user_query",
"ai_response",
"feedback_rating",
"feedback_rating_raw",
"feedback_comment",
"feedback_source",
"feedback_date",
"message_date",
"from_account_name",
"from_end_user_id",
"has_comment",
]
writer = csv.DictWriter(output, fieldnames=headers)
writer.writeheader()
writer.writerows(data)
# Create response without requiring app context
response = Response(output.getvalue(), mimetype="text/csv; charset=utf-8-sig")
response.headers["Content-Disposition"] = (
f"attachment; filename=dify_feedback_export_{app_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
)
return response
@staticmethod
def _export_json(data, app_id):
"""Export data as JSON"""
response_data = {
"export_info": {
"app_id": app_id,
"export_date": datetime.now().isoformat(),
"total_records": len(data),
"data_source": "dify_feedback_export",
},
"feedback_data": data,
}
# Create response without requiring app context
response = Response(
json.dumps(response_data, ensure_ascii=False, indent=2),
mimetype="application/json; charset=utf-8",
)
response.headers["Content-Disposition"] = (
f"attachment; filename=dify_feedback_export_{app_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
)
return response
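# Illustrative sketch (the app id and dates are placeholders; requires a database
# session bound to the Flask app): export disliked feedback that has a comment,
# within a date window, as CSV.
if __name__ == "__main__":
    response = FeedbackService.export_feedbacks(
        app_id="app-uuid",
        rating="dislike",
        has_comment=True,
        start_date="2025-01-01",
        end_date="2025-01-31",
        format_type="csv",
    )
    print(response.headers["Content-Disposition"])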

View File

@@ -0,0 +1,245 @@
import hashlib
import os
import uuid
from typing import Literal, Union
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import NotFound
from configs import dify_config
from constants import (
AUDIO_EXTENSIONS,
DOCUMENT_EXTENSIONS,
IMAGE_EXTENSIONS,
VIDEO_EXTENSIONS,
)
from core.file import helpers as file_helpers
from core.rag.extractor.extract_processor import ExtractProcessor
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id
from models import Account
from models.enums import CreatorUserRole
from models.model import EndUser, UploadFile
from .errors.file import BlockedFileExtensionError, FileTooLargeError, UnsupportedFileTypeError
PREVIEW_WORDS_LIMIT = 3000
class FileService:
_session_maker: sessionmaker[Session]
def __init__(self, session_factory: sessionmaker | Engine | None = None):
if isinstance(session_factory, Engine):
self._session_maker = sessionmaker(bind=session_factory)
elif isinstance(session_factory, sessionmaker):
self._session_maker = session_factory
else:
raise AssertionError("must be a sessionmaker or an Engine.")
def upload_file(
self,
*,
filename: str,
content: bytes,
mimetype: str,
user: Union[Account, EndUser],
source: Literal["datasets"] | None = None,
source_url: str = "",
) -> UploadFile:
# get file extension
extension = os.path.splitext(filename)[1].lstrip(".").lower()
# check if filename contains invalid characters
if any(c in filename for c in ["/", "\\", ":", "*", "?", '"', "<", ">", "|"]):
raise ValueError("Filename contains invalid characters")
if len(filename) > 200:
filename = filename.split(".")[0][:200] + "." + extension
# check if extension is in blacklist
if extension and extension in dify_config.UPLOAD_FILE_EXTENSION_BLACKLIST:
raise BlockedFileExtensionError(f"File extension '.{extension}' is not allowed for security reasons")
if source == "datasets" and extension not in DOCUMENT_EXTENSIONS:
raise UnsupportedFileTypeError()
# get file size
file_size = len(content)
# check if the file size is exceeded
if not FileService.is_file_size_within_limit(extension=extension, file_size=file_size):
raise FileTooLargeError
# generate file key
file_uuid = str(uuid.uuid4())
current_tenant_id = extract_tenant_id(user)
file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." + extension
# save file to storage
storage.save(file_key, content)
# save file to db
upload_file = UploadFile(
tenant_id=current_tenant_id or "",
storage_type=dify_config.STORAGE_TYPE,
key=file_key,
name=filename,
size=file_size,
extension=extension,
mime_type=mimetype,
created_by_role=(CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER),
created_by=user.id,
created_at=naive_utc_now(),
used=False,
hash=hashlib.sha3_256(content).hexdigest(),
source_url=source_url,
)
# The `UploadFile` ID is generated within its constructor, so flushing to retrieve the ID is unnecessary.
# We can directly generate the `source_url` here before committing.
if not upload_file.source_url:
upload_file.source_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
with self._session_maker(expire_on_commit=False) as session:
session.add(upload_file)
session.commit()
return upload_file
@staticmethod
def is_file_size_within_limit(*, extension: str, file_size: int) -> bool:
if extension in IMAGE_EXTENSIONS:
file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024
elif extension in VIDEO_EXTENSIONS:
file_size_limit = dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT * 1024 * 1024
elif extension in AUDIO_EXTENSIONS:
file_size_limit = dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT * 1024 * 1024
else:
file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024
return file_size <= file_size_limit
def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile:
if len(text_name) > 200:
text_name = text_name[:200]
        # use a uuid as the file name
file_uuid = str(uuid.uuid4())
file_key = "upload_files/" + tenant_id + "/" + file_uuid + ".txt"
# save file to storage
storage.save(file_key, text.encode("utf-8"))
# save file to db
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=file_key,
name=text_name,
size=len(text),
extension="txt",
mime_type="text/plain",
created_by=user_id,
created_by_role=CreatorUserRole.ACCOUNT,
created_at=naive_utc_now(),
used=True,
used_by=user_id,
used_at=naive_utc_now(),
)
with self._session_maker(expire_on_commit=False) as session:
session.add(upload_file)
session.commit()
return upload_file
def get_file_preview(self, file_id: str):
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found")
# extract text from file
extension = upload_file.extension
if extension.lower() not in DOCUMENT_EXTENSIONS:
raise UnsupportedFileTypeError()
text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True)
text = text[0:PREVIEW_WORDS_LIMIT] if text else ""
return text
def get_image_preview(self, file_id: str, timestamp: str, nonce: str, sign: str):
result = file_helpers.verify_image_signature(
upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign
)
if not result:
raise NotFound("File not found or signature is invalid")
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found or signature is invalid")
# extract text from file
extension = upload_file.extension
if extension.lower() not in IMAGE_EXTENSIONS:
raise UnsupportedFileTypeError()
generator = storage.load(upload_file.key, stream=True)
return generator, upload_file.mime_type
def get_file_generator_by_file_id(self, file_id: str, timestamp: str, nonce: str, sign: str):
result = file_helpers.verify_file_signature(upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign)
if not result:
raise NotFound("File not found or signature is invalid")
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found or signature is invalid")
generator = storage.load(upload_file.key, stream=True)
return generator, upload_file
def get_public_image_preview(self, file_id: str):
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found or signature is invalid")
            # ensure the file is an image before loading it
extension = upload_file.extension
if extension.lower() not in IMAGE_EXTENSIONS:
raise UnsupportedFileTypeError()
generator = storage.load(upload_file.key)
return generator, upload_file.mime_type
def get_file_content(self, file_id: str) -> str:
with self._session_maker(expire_on_commit=False) as session:
upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found")
content = storage.load(upload_file.key)
return content.decode("utf-8")
def delete_file(self, file_id: str):
with self._session_maker() as session, session.begin():
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id))
if not upload_file:
return
storage.delete(upload_file.key)
session.delete(upload_file)

View File

@@ -0,0 +1,177 @@
import logging
import time
from typing import Any
from core.app.app_config.entities import ModelConfig
from core.model_runtime.entities import LLMMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from models import Account
from models.dataset import Dataset, DatasetQuery
logger = logging.getLogger(__name__)
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 4,
"score_threshold_enabled": False,
}
class HitTestingService:
@classmethod
def retrieve(
cls,
dataset: Dataset,
query: str,
account: Account,
retrieval_model: Any, # FIXME drop this any
external_retrieval_model: dict,
limit: int = 10,
):
start = time.perf_counter()
        # get the retrieval model; if it is not set, fall back to the dataset's model or the default
if not retrieval_model:
retrieval_model = dataset.retrieval_model or default_retrieval_model
document_ids_filter = None
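        # optionally narrow retrieval to documents that match the manually configured metadata filtering conditions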
metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
if metadata_filtering_conditions:
dataset_retrieval = DatasetRetrieval()
from core.app.app_config.entities import MetadataFilteringCondition
metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
dataset_ids=[dataset.id],
query=query,
metadata_filtering_mode="manual",
metadata_filtering_conditions=metadata_filtering_conditions,
inputs={},
tenant_id="",
user_id="",
metadata_model_config=ModelConfig(provider="", name="", mode=LLMMode.CHAT, completion_params={}),
)
if metadata_filter_document_ids:
document_ids_filter = metadata_filter_document_ids.get(dataset.id, [])
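            # a metadata condition that matches no documents means nothing can be retrieved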
if metadata_condition and not document_ids_filter:
return cls.compact_retrieve_response(query, [])
all_documents = RetrievalService.retrieve(
retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k", 4),
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
document_ids_filter=document_ids_filter,
)
end = time.perf_counter()
logger.debug("Hit testing retrieve in %s seconds", end - start)
dataset_query = DatasetQuery(
dataset_id=dataset.id,
content=query,
source="hit_testing",
source_app_id=None,
created_by_role="account",
created_by=account.id,
)
db.session.add(dataset_query)
db.session.commit()
return cls.compact_retrieve_response(query, all_documents)
@classmethod
def external_retrieve(
cls,
dataset: Dataset,
query: str,
account: Account,
external_retrieval_model: dict,
metadata_filtering_conditions: dict,
):
if dataset.provider != "external":
return {
"query": {"content": query},
"records": [],
}
start = time.perf_counter()
all_documents = RetrievalService.external_retrieve(
dataset_id=dataset.id,
query=cls.escape_query_for_search(query),
external_retrieval_model=external_retrieval_model,
metadata_filtering_conditions=metadata_filtering_conditions,
)
end = time.perf_counter()
logger.debug("External knowledge hit testing retrieve in %s seconds", end - start)
dataset_query = DatasetQuery(
dataset_id=dataset.id,
content=query,
source="hit_testing",
source_app_id=None,
created_by_role="account",
created_by=account.id,
)
db.session.add(dataset_query)
db.session.commit()
return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))
@classmethod
def compact_retrieve_response(cls, query: str, documents: list[Document]) -> dict[Any, Any]:
records = RetrievalService.format_retrieval_documents(documents)
return {
"query": {
"content": query,
},
"records": [record.model_dump() for record in records],
}
@classmethod
def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> dict[Any, Any]:
records = []
if dataset.provider == "external":
for document in documents:
record = {
"content": document.get("content", None),
"title": document.get("title", None),
"score": document.get("score", None),
"metadata": document.get("metadata", None),
}
records.append(record)
return {
"query": {"content": query},
"records": records,
}
return {"query": {"content": query}, "records": []}
@classmethod
def hit_testing_args_check(cls, args):
query = args["query"]
if not query or len(query) > 250:
raise ValueError("Query is required and cannot exceed 250 characters")
@staticmethod
def escape_query_for_search(query: str) -> str:
return query.replace('"', '\\"')

View File

@@ -0,0 +1,45 @@
import boto3
from configs import dify_config
class ExternalDatasetTestService:
# this service is only for internal testing
@staticmethod
def knowledge_retrieval(retrieval_setting: dict, query: str, knowledge_id: str):
# get bedrock client
client = boto3.client(
"bedrock-agent-runtime",
aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY,
aws_access_key_id=dify_config.AWS_ACCESS_KEY_ID,
# example: us-east-1
region_name="us-east-1",
)
# fetch external knowledge retrieval
response = client.retrieve(
knowledgeBaseId=knowledge_id,
retrievalConfiguration={
"vectorSearchConfiguration": {
"numberOfResults": retrieval_setting.get("top_k"),
"overrideSearchType": "HYBRID",
}
},
retrievalQuery={"text": query},
)
# parse response
results = []
if response.get("ResponseMetadata") and response.get("ResponseMetadata").get("HTTPStatusCode") == 200:
if response.get("retrievalResults"):
retrieval_results = response.get("retrievalResults")
for retrieval_result in retrieval_results:
# filter out results with score less than threshold
if retrieval_result.get("score") < retrieval_setting.get("score_threshold", 0.0):
continue
result = {
"metadata": retrieval_result.get("metadata"),
"score": retrieval_result.get("score"),
"title": retrieval_result.get("metadata").get("x-amz-bedrock-kb-source-uri"),
"content": retrieval_result.get("content").get("text"),
}
results.append(result)
return {"records": results}

View File

@@ -0,0 +1,305 @@
import json
from typing import Union
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.llm_generator import LLMGenerator
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback
from services.conversation_service import ConversationService
from services.errors.message import (
FirstMessageNotExistsError,
LastMessageNotExistsError,
MessageNotExistsError,
SuggestedQuestionsAfterAnswerDisabledError,
)
from services.workflow_service import WorkflowService
class MessageService:
@classmethod
def pagination_by_first_id(
cls,
app_model: App,
user: Union[Account, EndUser] | None,
conversation_id: str,
first_id: str | None,
limit: int,
order: str = "asc",
) -> InfiniteScrollPagination:
if not user:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
if not conversation_id:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
conversation = ConversationService.get_conversation(
app_model=app_model, user=user, conversation_id=conversation_id
)
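        # fetch one row more than requested so we can tell whether older messages remain (has_more)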
fetch_limit = limit + 1
if first_id:
first_message = (
db.session.query(Message)
.where(Message.conversation_id == conversation.id, Message.id == first_id)
.first()
)
if not first_message:
raise FirstMessageNotExistsError()
history_messages = (
db.session.query(Message)
.where(
Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at,
Message.id != first_message.id,
)
.order_by(Message.created_at.desc())
.limit(fetch_limit)
.all()
)
else:
history_messages = (
db.session.query(Message)
.where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(fetch_limit)
.all()
)
has_more = False
if len(history_messages) > limit:
has_more = True
history_messages = history_messages[:-1]
if order == "asc":
history_messages = list(reversed(history_messages))
return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more)
@classmethod
def pagination_by_last_id(
cls,
app_model: App,
user: Union[Account, EndUser] | None,
last_id: str | None,
limit: int,
conversation_id: str | None = None,
include_ids: list | None = None,
) -> InfiniteScrollPagination:
if not user:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
base_query = db.session.query(Message)
fetch_limit = limit + 1
if conversation_id is not None:
conversation = ConversationService.get_conversation(
app_model=app_model, user=user, conversation_id=conversation_id
)
base_query = base_query.where(Message.conversation_id == conversation.id)
# Check if include_ids is not None and not empty to avoid WHERE false condition
if include_ids is not None:
if len(include_ids) == 0:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
base_query = base_query.where(Message.id.in_(include_ids))
if last_id:
last_message = base_query.where(Message.id == last_id).first()
if not last_message:
raise LastMessageNotExistsError()
history_messages = (
base_query.where(Message.created_at < last_message.created_at, Message.id != last_message.id)
.order_by(Message.created_at.desc())
.limit(fetch_limit)
.all()
)
else:
history_messages = base_query.order_by(Message.created_at.desc()).limit(fetch_limit).all()
has_more = False
if len(history_messages) > limit:
has_more = True
history_messages = history_messages[:-1]
return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more)
@classmethod
def create_feedback(
cls,
*,
app_model: App,
message_id: str,
user: Union[Account, EndUser] | None,
rating: str | None,
content: str | None,
):
if not user:
raise ValueError("user cannot be None")
message = cls.get_message(app_model=app_model, user=user, message_id=message_id)
feedback = message.user_feedback if isinstance(user, EndUser) else message.admin_feedback
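        # clear, update, or create the feedback depending on whether a rating was given and one already exists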
if not rating and feedback:
db.session.delete(feedback)
elif rating and feedback:
feedback.rating = rating
feedback.content = content
elif not rating and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
else:
assert rating is not None
feedback = MessageFeedback(
app_id=app_model.id,
conversation_id=message.conversation_id,
message_id=message.id,
rating=rating,
content=content,
from_source=("user" if isinstance(user, EndUser) else "admin"),
from_end_user_id=(user.id if isinstance(user, EndUser) else None),
from_account_id=(user.id if isinstance(user, Account) else None),
)
db.session.add(feedback)
db.session.commit()
return feedback
@classmethod
def get_all_messages_feedbacks(cls, app_model: App, page: int, limit: int):
"""Get all feedbacks of an app"""
offset = (page - 1) * limit
feedbacks = (
db.session.query(MessageFeedback)
.where(MessageFeedback.app_id == app_model.id)
.order_by(MessageFeedback.created_at.desc(), MessageFeedback.id.desc())
.limit(limit)
.offset(offset)
.all()
)
return [record.to_dict() for record in feedbacks]
@classmethod
def get_message(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
message = (
db.session.query(Message)
.where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
.first()
)
if not message:
raise MessageNotExistsError()
return message
@classmethod
def get_suggested_questions_after_answer(
cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str, invoke_from: InvokeFrom
) -> list[str]:
if not user:
raise ValueError("user cannot be None")
message = cls.get_message(app_model=app_model, user=user, message_id=message_id)
conversation = ConversationService.get_conversation(
app_model=app_model, conversation_id=message.conversation_id, user=user
)
model_manager = ModelManager()
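        # advanced-chat apps read the suggested-questions flag from the workflow; other modes read it from the app model config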
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow_service = WorkflowService()
if invoke_from == InvokeFrom.DEBUGGER:
workflow = workflow_service.get_draft_workflow(app_model=app_model)
else:
workflow = workflow_service.get_published_workflow(app_model=app_model)
if workflow is None:
return []
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
if not app_config.additional_features:
raise ValueError("Additional features not found")
if not app_config.additional_features.suggested_questions_after_answer:
raise SuggestedQuestionsAfterAnswerDisabledError()
model_instance = model_manager.get_default_model_instance(
tenant_id=app_model.tenant_id, model_type=ModelType.LLM
)
else:
if not conversation.override_model_configs:
app_model_config = (
db.session.query(AppModelConfig)
.where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.first()
)
else:
conversation_override_model_configs = json.loads(conversation.override_model_configs)
app_model_config = AppModelConfig(
id=conversation.app_model_config_id,
app_id=app_model.id,
)
app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
if not app_model_config:
raise ValueError("did not find app model config")
suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict
if suggested_questions_after_answer.get("enabled", False) is False:
raise SuggestedQuestionsAfterAnswerDisabledError()
model_instance = model_manager.get_model_instance(
tenant_id=app_model.tenant_id,
provider=app_model_config.model_dict["provider"],
model_type=ModelType.LLM,
model=app_model_config.model_dict["name"],
)
# get memory of conversation (read-only)
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
histories = memory.get_history_prompt_text(
max_token_limit=3000,
message_limit=3,
)
with measure_time() as timer:
questions_sequence = LLMGenerator.generate_suggested_questions_after_answer(
tenant_id=app_model.tenant_id, histories=histories
)
questions: list[str] = list(questions_sequence)
# get tracing instance
trace_manager = TraceQueueManager(app_id=app_model.id)
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id=message_id, suggested_question=questions, timer=timer
)
)
return questions

View File

@@ -0,0 +1,283 @@
import copy
import logging
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding
from services.dataset_service import DocumentService
from services.entities.knowledge_entities.knowledge_entities import (
MetadataArgs,
MetadataOperationData,
)
logger = logging.getLogger(__name__)
class MetadataService:
@staticmethod
def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata:
# check if metadata name is too long
if len(metadata_args.name) > 255:
raise ValueError("Metadata name cannot exceed 255 characters.")
current_user, current_tenant_id = current_account_with_tenant()
# check if metadata name already exists
if (
db.session.query(DatasetMetadata)
.filter_by(tenant_id=current_tenant_id, dataset_id=dataset_id, name=metadata_args.name)
.first()
):
raise ValueError("Metadata name already exists.")
for field in BuiltInField:
if field.value == metadata_args.name:
raise ValueError("Metadata name already exists in Built-in fields.")
metadata = DatasetMetadata(
tenant_id=current_tenant_id,
dataset_id=dataset_id,
type=metadata_args.type,
name=metadata_args.name,
created_by=current_user.id,
)
db.session.add(metadata)
db.session.commit()
return metadata
@staticmethod
def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore
# check if metadata name is too long
if len(name) > 255:
raise ValueError("Metadata name cannot exceed 255 characters.")
lock_key = f"dataset_metadata_lock_{dataset_id}"
# check if metadata name already exists
current_user, current_tenant_id = current_account_with_tenant()
if (
db.session.query(DatasetMetadata)
.filter_by(tenant_id=current_tenant_id, dataset_id=dataset_id, name=name)
.first()
):
raise ValueError("Metadata name already exists.")
for field in BuiltInField:
if field.value == name:
raise ValueError("Metadata name already exists in Built-in fields.")
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first()
if metadata is None:
raise ValueError("Metadata not found.")
old_name = metadata.name
metadata.name = name
metadata.updated_by = current_user.id
metadata.updated_at = naive_utc_now()
# update related documents
dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
for document in documents:
if not document.doc_metadata:
doc_metadata = {}
else:
doc_metadata = copy.deepcopy(document.doc_metadata)
value = doc_metadata.pop(old_name, None)
doc_metadata[name] = value
document.doc_metadata = doc_metadata
db.session.add(document)
db.session.commit()
return metadata
except Exception:
logger.exception("Update metadata name failed")
finally:
redis_client.delete(lock_key)
@staticmethod
def delete_metadata(dataset_id: str, metadata_id: str):
lock_key = f"dataset_metadata_lock_{dataset_id}"
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first()
if metadata is None:
raise ValueError("Metadata not found.")
db.session.delete(metadata)
            # remove this metadata field from related documents
dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
for document in documents:
if not document.doc_metadata:
doc_metadata = {}
else:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata.pop(metadata.name, None)
document.doc_metadata = doc_metadata
db.session.add(document)
db.session.commit()
return metadata
except Exception:
logger.exception("Delete metadata failed")
finally:
redis_client.delete(lock_key)
@staticmethod
def get_built_in_fields():
return [
{"name": BuiltInField.document_name, "type": "string"},
{"name": BuiltInField.uploader, "type": "string"},
{"name": BuiltInField.upload_date, "type": "time"},
{"name": BuiltInField.last_update_date, "type": "time"},
{"name": BuiltInField.source, "type": "string"},
]
@staticmethod
def enable_built_in_field(dataset: Dataset):
if dataset.built_in_field_enabled:
return
lock_key = f"dataset_metadata_lock_{dataset.id}"
try:
MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
db.session.add(dataset)
documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
if documents:
for document in documents:
if not document.doc_metadata:
doc_metadata = {}
else:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata[BuiltInField.document_name] = document.name
doc_metadata[BuiltInField.uploader] = document.uploader
doc_metadata[BuiltInField.upload_date] = document.upload_date.timestamp()
doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp()
doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type]
document.doc_metadata = doc_metadata
db.session.add(document)
dataset.built_in_field_enabled = True
db.session.commit()
except Exception:
logger.exception("Enable built-in field failed")
finally:
redis_client.delete(lock_key)
@staticmethod
def disable_built_in_field(dataset: Dataset):
if not dataset.built_in_field_enabled:
return
lock_key = f"dataset_metadata_lock_{dataset.id}"
try:
MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
db.session.add(dataset)
documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
document_ids = []
if documents:
for document in documents:
if not document.doc_metadata:
doc_metadata = {}
else:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata.pop(BuiltInField.document_name, None)
doc_metadata.pop(BuiltInField.uploader, None)
doc_metadata.pop(BuiltInField.upload_date, None)
doc_metadata.pop(BuiltInField.last_update_date, None)
doc_metadata.pop(BuiltInField.source, None)
document.doc_metadata = doc_metadata
db.session.add(document)
document_ids.append(document.id)
dataset.built_in_field_enabled = False
db.session.commit()
except Exception:
logger.exception("Disable built-in field failed")
finally:
redis_client.delete(lock_key)
@staticmethod
def update_documents_metadata(dataset: Dataset, metadata_args: MetadataOperationData):
for operation in metadata_args.operation_data:
lock_key = f"document_metadata_lock_{operation.document_id}"
try:
MetadataService.knowledge_base_metadata_lock_check(None, operation.document_id)
document = DocumentService.get_document(dataset.id, operation.document_id)
if document is None:
raise ValueError("Document not found.")
if operation.partial_update:
doc_metadata = copy.deepcopy(document.doc_metadata) if document.doc_metadata else {}
else:
doc_metadata = {}
for metadata_value in operation.metadata_list:
doc_metadata[metadata_value.name] = metadata_value.value
if dataset.built_in_field_enabled:
doc_metadata[BuiltInField.document_name] = document.name
doc_metadata[BuiltInField.uploader] = document.uploader
doc_metadata[BuiltInField.upload_date] = document.upload_date.timestamp()
doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp()
doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type]
document.doc_metadata = doc_metadata
db.session.add(document)
db.session.commit()
                # handle metadata bindings
if not operation.partial_update:
db.session.query(DatasetMetadataBinding).filter_by(document_id=operation.document_id).delete()
current_user, current_tenant_id = current_account_with_tenant()
for metadata_value in operation.metadata_list:
# check if binding already exists
if operation.partial_update:
existing_binding = (
db.session.query(DatasetMetadataBinding)
.filter_by(document_id=operation.document_id, metadata_id=metadata_value.id)
.first()
)
if existing_binding:
continue
dataset_metadata_binding = DatasetMetadataBinding(
tenant_id=current_tenant_id,
dataset_id=dataset.id,
document_id=operation.document_id,
metadata_id=metadata_value.id,
created_by=current_user.id,
)
db.session.add(dataset_metadata_binding)
db.session.commit()
except Exception:
logger.exception("Update documents metadata failed")
finally:
redis_client.delete(lock_key)
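    # Best-effort mutual exclusion: set a Redis key with a one-hour TTL and raise if another operation already holds it.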
@staticmethod
def knowledge_base_metadata_lock_check(dataset_id: str | None, document_id: str | None):
if dataset_id:
lock_key = f"dataset_metadata_lock_{dataset_id}"
if redis_client.get(lock_key):
raise ValueError("Another knowledge base metadata operation is running, please wait a moment.")
redis_client.set(lock_key, 1, ex=3600)
if document_id:
lock_key = f"document_metadata_lock_{document_id}"
if redis_client.get(lock_key):
raise ValueError("Another document metadata operation is running, please wait a moment.")
redis_client.set(lock_key, 1, ex=3600)
@staticmethod
def get_dataset_metadatas(dataset: Dataset):
return {
"doc_metadata": [
{
"id": item.get("id"),
"name": item.get("name"),
"type": item.get("type"),
"count": db.session.query(DatasetMetadataBinding)
.filter_by(metadata_id=item.get("id"), dataset_id=dataset.id)
.count(),
}
for item in dataset.doc_metadata or []
if item.get("id") != "built-in"
],
"built_in_field_enabled": dataset.built_in_field_enabled,
}

View File

@@ -0,0 +1,620 @@
import json
import logging
from json import JSONDecodeError
from typing import Union
from sqlalchemy import or_, select
from constants import HIDDEN_VALUE
from core.entities.provider_configuration import ProviderConfiguration
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_manager import LBModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import (
ModelCredentialSchema,
ProviderCredentialSchema,
)
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.provider_manager import ProviderManager
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential
logger = logging.getLogger(__name__)
class ModelLoadBalancingService:
def __init__(self):
self.provider_manager = ProviderManager()
def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str):
"""
enable model load balancing.
:param tenant_id: workspace id
:param provider: provider name
:param model: model name
:param model_type: model type
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider configuration
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
raise ValueError(f"Provider {provider} does not exist.")
# Enable model load balancing
provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str):
"""
disable model load balancing.
:param tenant_id: workspace id
:param provider: provider name
:param model: model name
:param model_type: model type
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider configuration
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
raise ValueError(f"Provider {provider} does not exist.")
# disable model load balancing
provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
def get_load_balancing_configs(
self, tenant_id: str, provider: str, model: str, model_type: str, config_from: str = ""
) -> tuple[bool, list[dict]]:
"""
Get load balancing configurations.
:param tenant_id: workspace id
:param provider: provider name
:param model: model name
:param model_type: model type
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider configuration
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
model_type_enum = ModelType.value_of(model_type)
# Get provider model setting
provider_model_setting = provider_configuration.get_provider_model_setting(
model_type=model_type_enum,
model=model,
)
is_load_balancing_enabled = False
if provider_model_setting and provider_model_setting.load_balancing_enabled:
is_load_balancing_enabled = True
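        # map the config source to the credential source type stored on the load balancing records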
if config_from == "predefined-model":
credential_source_type = "provider"
else:
credential_source_type = "custom_model"
# Get load balancing configurations
load_balancing_configs = (
db.session.query(LoadBalancingModelConfig)
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
or_(
LoadBalancingModelConfig.credential_source_type == credential_source_type,
LoadBalancingModelConfig.credential_source_type.is_(None),
),
)
.order_by(LoadBalancingModelConfig.created_at)
.all()
)
if provider_configuration.custom_configuration.provider:
# check if the inherit configuration exists,
# inherit is represented for the provider or model custom credentials
inherit_config_exists = False
for load_balancing_config in load_balancing_configs:
if load_balancing_config.name == "__inherit__":
inherit_config_exists = True
break
if not inherit_config_exists:
# Initialize the inherit configuration
inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type_enum)
# prepend the inherit configuration
load_balancing_configs.insert(0, inherit_config)
else:
# move the inherit configuration to the first
for i, load_balancing_config in enumerate(load_balancing_configs[:]):
if load_balancing_config.name == "__inherit__":
inherit_config = load_balancing_configs.pop(i)
load_balancing_configs.insert(0, inherit_config)
# Get credential form schemas from model credential schema or provider credential schema
credential_schemas = self._get_credential_schema(provider_configuration)
# Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
# fetch status and ttl for each config
datas = []
for load_balancing_config in load_balancing_configs:
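            # check whether this config is currently in cooldown and how many seconds remain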
in_cooldown, ttl = LBModelManager.get_config_in_cooldown_and_ttl(
tenant_id=tenant_id,
provider=provider,
model=model,
model_type=model_type_enum,
config_id=load_balancing_config.id,
)
try:
if load_balancing_config.encrypted_config:
credentials: dict[str, object] = json.loads(load_balancing_config.encrypted_config)
else:
credentials = {}
except JSONDecodeError:
credentials = {}
# Get provider credential secret variables
credential_secret_variables = provider_configuration.extract_secret_variables(
credential_schemas.credential_form_schemas
)
# decrypt credentials
for variable in credential_secret_variables:
if variable in credentials:
try:
token_value = credentials.get(variable)
if isinstance(token_value, str):
credentials[variable] = encrypter.decrypt_token_with_decoding(
token_value,
decoding_rsa_key,
decoding_cipher_rsa,
)
except ValueError:
pass
# Obfuscate credentials
credentials = provider_configuration.obfuscated_credentials(
credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas
)
datas.append(
{
"id": load_balancing_config.id,
"name": load_balancing_config.name,
"credentials": credentials,
"credential_id": load_balancing_config.credential_id,
"enabled": load_balancing_config.enabled,
"in_cooldown": in_cooldown,
"ttl": ttl,
}
)
return is_load_balancing_enabled, datas
def get_load_balancing_config(
self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str
) -> dict | None:
"""
Get load balancing configuration.
:param tenant_id: workspace id
:param provider: provider name
:param model: model name
:param model_type: model type
:param config_id: load balancing config id
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider configuration
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
model_type_enum = ModelType.value_of(model_type)
# Get load balancing configurations
load_balancing_model_config = (
db.session.query(LoadBalancingModelConfig)
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id,
)
.first()
)
if not load_balancing_model_config:
return None
try:
if load_balancing_model_config.encrypted_config:
credentials = json.loads(load_balancing_model_config.encrypted_config)
else:
credentials = {}
except JSONDecodeError:
credentials = {}
# Get credential form schemas from model credential schema or provider credential schema
credential_schemas = self._get_credential_schema(provider_configuration)
# Obfuscate credentials
credentials = provider_configuration.obfuscated_credentials(
credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas
)
return {
"id": load_balancing_model_config.id,
"name": load_balancing_model_config.name,
"credentials": credentials,
"enabled": load_balancing_model_config.enabled,
}
def _init_inherit_config(
self, tenant_id: str, provider: str, model: str, model_type: ModelType
) -> LoadBalancingModelConfig:
"""
Initialize the inherit configuration.
:param tenant_id: workspace id
:param provider: provider name
:param model: model name
:param model_type: model type
:return:
"""
# Initialize the inherit configuration
inherit_config = LoadBalancingModelConfig(
tenant_id=tenant_id,
provider_name=provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
name="__inherit__",
)
db.session.add(inherit_config)
db.session.commit()
return inherit_config
def update_load_balancing_configs(
self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict], config_from: str
):
"""
Update load balancing configurations.
:param tenant_id: workspace id
:param provider: provider name
:param model: model name
:param model_type: model type
:param configs: load balancing configs
:param config_from: predefined-model or custom-model
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider configuration
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
model_type_enum = ModelType.value_of(model_type)
if not isinstance(configs, list):
raise ValueError("Invalid load balancing configs")
current_load_balancing_configs = db.session.scalars(
select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
).all()
# id as key, config as value
current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}
updated_config_ids = set()
for config in configs:
if not isinstance(config, dict):
raise ValueError("Invalid load balancing config")
config_id = config.get("id")
name = config.get("name")
credentials = config.get("credentials")
credential_id = config.get("credential_id")
enabled = config.get("enabled")
credential_record: ProviderCredential | ProviderModelCredential | None = None
if credential_id:
if config_from == "predefined-model":
credential_record = (
db.session.query(ProviderCredential)
.filter_by(
id=credential_id,
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
)
.first()
)
else:
credential_record = (
db.session.query(ProviderModelCredential)
.filter_by(
id=credential_id,
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
model_name=model,
model_type=model_type_enum.to_origin_model_type(),
)
.first()
)
if not credential_record:
raise ValueError(f"Provider credential with id {credential_id} not found")
name = credential_record.credential_name
if not name:
raise ValueError("Invalid load balancing config name")
if enabled is None:
raise ValueError("Invalid load balancing config enabled")
# is config exists
if config_id:
config_id = str(config_id)
if config_id not in current_load_balancing_configs_dict:
raise ValueError(f"Invalid load balancing config id: {config_id}")
updated_config_ids.add(config_id)
load_balancing_config = current_load_balancing_configs_dict[config_id]
if credentials:
if not isinstance(credentials, dict):
raise ValueError("Invalid load balancing config credentials")
# validate custom provider config
credentials = self._custom_credentials_validate(
tenant_id=tenant_id,
provider_configuration=provider_configuration,
model_type=model_type_enum,
model=model,
credentials=credentials,
load_balancing_model_config=load_balancing_config,
validate=False,
)
# update load balancing config
load_balancing_config.encrypted_config = json.dumps(credentials)
load_balancing_config.name = name
load_balancing_config.enabled = enabled
load_balancing_config.updated_at = naive_utc_now()
db.session.commit()
self._clear_credentials_cache(tenant_id, config_id)
else:
# create load balancing config
if name == "__inherit__":
raise ValueError("Invalid load balancing config name")
if credential_id:
credential_source = "provider" if config_from == "predefined-model" else "custom_model"
assert credential_record is not None
load_balancing_model_config = LoadBalancingModelConfig(
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
model_type=model_type_enum.to_origin_model_type(),
model_name=model,
name=credential_record.credential_name,
encrypted_config=credential_record.encrypted_config,
credential_id=credential_id,
credential_source_type=credential_source,
)
else:
if not credentials:
raise ValueError("Invalid load balancing config credentials")
if not isinstance(credentials, dict):
raise ValueError("Invalid load balancing config credentials")
# validate custom provider config
credentials = self._custom_credentials_validate(
tenant_id=tenant_id,
provider_configuration=provider_configuration,
model_type=model_type_enum,
model=model,
credentials=credentials,
validate=False,
)
# create load balancing config
load_balancing_model_config = LoadBalancingModelConfig(
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
model_type=model_type_enum.to_origin_model_type(),
model_name=model,
name=name,
encrypted_config=json.dumps(credentials),
)
db.session.add(load_balancing_model_config)
db.session.commit()
# get deleted config ids
deleted_config_ids = set(current_load_balancing_configs_dict.keys()) - updated_config_ids
for config_id in deleted_config_ids:
db.session.delete(current_load_balancing_configs_dict[config_id])
db.session.commit()
self._clear_credentials_cache(tenant_id, config_id)
def validate_load_balancing_credentials(
self,
tenant_id: str,
provider: str,
model: str,
model_type: str,
credentials: dict,
config_id: str | None = None,
):
"""
Validate load balancing credentials.
:param tenant_id: workspace id
:param provider: provider name
        :param model: model name
        :param model_type: model type
:param credentials: credentials
:param config_id: load balancing config id
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider configuration
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
model_type_enum = ModelType.value_of(model_type)
load_balancing_model_config = None
if config_id:
# Get load balancing config
load_balancing_model_config = (
db.session.query(LoadBalancingModelConfig)
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id,
)
.first()
)
if not load_balancing_model_config:
raise ValueError(f"Load balancing config {config_id} does not exist.")
# Validate custom provider config
self._custom_credentials_validate(
tenant_id=tenant_id,
provider_configuration=provider_configuration,
model_type=model_type_enum,
model=model,
credentials=credentials,
load_balancing_model_config=load_balancing_model_config,
)
def _custom_credentials_validate(
self,
tenant_id: str,
provider_configuration: ProviderConfiguration,
model_type: ModelType,
model: str,
credentials: dict,
load_balancing_model_config: LoadBalancingModelConfig | None = None,
validate: bool = True,
):
"""
Validate custom credentials.
:param tenant_id: workspace id
:param provider_configuration: provider configuration
:param model_type: model type
:param model: model name
:param credentials: credentials
:param load_balancing_model_config: load balancing model config
:param validate: validate credentials
:return:
"""
# Get credential form schemas from model credential schema or provider credential schema
credential_schemas = self._get_credential_schema(provider_configuration)
# Get provider credential secret variables
provider_credential_secret_variables = provider_configuration.extract_secret_variables(
credential_schemas.credential_form_schemas
)
if load_balancing_model_config:
try:
                # load the original stored credentials
if load_balancing_model_config.encrypted_config:
original_credentials = json.loads(load_balancing_model_config.encrypted_config)
else:
original_credentials = {}
except JSONDecodeError:
original_credentials = {}
# encrypt credentials
for key, value in credentials.items():
if key in provider_credential_secret_variables:
                    # if the secret input is [__HIDDEN__], keep the original stored value
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key])
if validate:
model_provider_factory = ModelProviderFactory(tenant_id)
if isinstance(credential_schemas, ModelCredentialSchema):
credentials = model_provider_factory.model_credentials_validate(
provider=provider_configuration.provider.provider,
model_type=model_type,
model=model,
credentials=credentials,
)
else:
credentials = model_provider_factory.provider_credentials_validate(
provider=provider_configuration.provider.provider, credentials=credentials
)
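        # re-encrypt secret fields before returning the credentials for persistence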
for key, value in credentials.items():
if key in provider_credential_secret_variables:
credentials[key] = encrypter.encrypt_token(tenant_id, value)
return credentials
def _get_credential_schema(
self, provider_configuration: ProviderConfiguration
) -> Union[ModelCredentialSchema, ProviderCredentialSchema]:
"""Get form schemas."""
if provider_configuration.provider.model_credential_schema:
return provider_configuration.provider.model_credential_schema
elif provider_configuration.provider.provider_credential_schema:
return provider_configuration.provider.provider_credential_schema
else:
raise ValueError("No credential schema found")
def _clear_credentials_cache(self, tenant_id: str, config_id: str):
"""
Clear credentials cache.
:param tenant_id: workspace id
:param config_id: load balancing config id
:return:
"""
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id, identity_id=config_id, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
)
provider_model_credentials_cache.delete()

View File

@@ -0,0 +1,554 @@
import logging
from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.provider_manager import ProviderManager
from models.provider import ProviderType
from services.entities.model_provider_entities import (
CustomConfigurationResponse,
CustomConfigurationStatus,
DefaultModelResponse,
ModelWithProviderEntityResponse,
ProviderResponse,
ProviderWithModelsResponse,
SimpleProviderEntityResponse,
SystemConfigurationResponse,
)
from services.errors.app_model_config import ProviderNotFoundError
logger = logging.getLogger(__name__)
class ModelProviderService:
"""
Model Provider Service
"""
def __init__(self):
self.provider_manager = ProviderManager()
def _get_provider_configuration(self, tenant_id: str, provider: str):
"""
Get provider configuration or raise exception if not found.
Args:
tenant_id: Workspace identifier
provider: Provider name
Returns:
Provider configuration instance
Raises:
ProviderNotFoundError: If provider doesn't exist
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
raise ProviderNotFoundError(f"Provider {provider} does not exist.")
return provider_configuration
def get_provider_list(self, tenant_id: str, model_type: str | None = None) -> list[ProviderResponse]:
"""
get provider list.
:param tenant_id: workspace id
:param model_type: model type
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
provider_responses = []
for provider_configuration in provider_configurations.values():
if model_type:
model_type_entity = ModelType.value_of(model_type)
if model_type_entity not in provider_configuration.provider.supported_model_types:
continue
provider_config = provider_configuration.custom_configuration.provider
model_config = provider_configuration.custom_configuration.models
can_added_models = provider_configuration.custom_configuration.can_added_models
provider_response = ProviderResponse(
tenant_id=tenant_id,
provider=provider_configuration.provider.provider,
label=provider_configuration.provider.label,
description=provider_configuration.provider.description,
icon_small=provider_configuration.provider.icon_small,
icon_large=provider_configuration.provider.icon_large,
background=provider_configuration.provider.background,
help=provider_configuration.provider.help,
supported_model_types=provider_configuration.provider.supported_model_types,
configurate_methods=provider_configuration.provider.configurate_methods,
provider_credential_schema=provider_configuration.provider.provider_credential_schema,
model_credential_schema=provider_configuration.provider.model_credential_schema,
preferred_provider_type=provider_configuration.preferred_provider_type,
custom_configuration=CustomConfigurationResponse(
status=CustomConfigurationStatus.ACTIVE
if provider_configuration.is_custom_configuration_available()
else CustomConfigurationStatus.NO_CONFIGURE,
current_credential_id=getattr(provider_config, "current_credential_id", None),
current_credential_name=getattr(provider_config, "current_credential_name", None),
available_credentials=getattr(provider_config, "available_credentials", []),
custom_models=model_config,
can_added_models=can_added_models,
),
system_configuration=SystemConfigurationResponse(
enabled=provider_configuration.system_configuration.enabled,
current_quota_type=provider_configuration.system_configuration.current_quota_type,
quota_configurations=provider_configuration.system_configuration.quota_configurations,
),
)
provider_responses.append(provider_response)
return provider_responses
def get_models_by_provider(self, tenant_id: str, provider: str) -> list[ModelWithProviderEntityResponse]:
"""
get provider models.
        For the model provider page: only a single provider may be passed to query its list of supported models.
:param tenant_id: workspace id
:param provider: provider name
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider available models
return [
ModelWithProviderEntityResponse(tenant_id=tenant_id, model=model)
for model in provider_configurations.get_models(provider=provider)
]
def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None:
"""
get provider credentials.
:param tenant_id: workspace id
:param provider: provider name
:param credential_id: credential id, if not provided, return current used credentials
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
return provider_configuration.get_provider_credential(credential_id=credential_id)
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict):
"""
validate provider credentials before saving.
:param tenant_id: workspace id
:param provider: provider name
:param credentials: provider credentials dict
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.validate_provider_credentials(credentials)
def create_provider_credential(
self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None
) -> None:
"""
Create and save new provider credentials.
:param tenant_id: workspace id
:param provider: provider name
:param credentials: provider credentials dict
:param credential_name: credential name
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.create_provider_credential(credentials, credential_name)
def update_provider_credential(
self,
tenant_id: str,
provider: str,
credentials: dict,
credential_id: str,
credential_name: str | None,
) -> None:
"""
update a saved provider credential (by credential_id).
:param tenant_id: workspace id
:param provider: provider name
:param credentials: provider credentials dict
:param credential_id: credential id
:param credential_name: credential name
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.update_provider_credential(
credential_id=credential_id,
credentials=credentials,
credential_name=credential_name,
)
def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str):
"""
remove a saved provider credential (by credential_id).
:param tenant_id: workspace id
:param provider: provider name
:param credential_id: credential id
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.delete_provider_credential(credential_id=credential_id)
def switch_active_provider_credential(self, tenant_id: str, provider: str, credential_id: str):
"""
:param tenant_id: workspace id
:param provider: provider name
:param credential_id: credential id
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.switch_active_provider_credential(credential_id=credential_id)
def get_model_credential(
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None
) -> dict | None:
"""
Retrieve model-specific credentials.
:param tenant_id: workspace id
:param provider: provider name
:param model_type: model type
:param model: model name
:param credential_id: Optional credential ID, uses current if not provided
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
return provider_configuration.get_custom_model_credential(
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
)
def validate_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict):
"""
validate model credentials.
:param tenant_id: workspace id
:param provider: provider name
:param model_type: model type
:param model: model name
:param credentials: model credentials dict
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.validate_custom_model_credentials(
model_type=ModelType.value_of(model_type), model=model, credentials=credentials
)
def create_model_credential(
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None
) -> None:
"""
create and save model credentials.
:param tenant_id: workspace id
:param provider: provider name
:param model_type: model type
:param model: model name
:param credentials: model credentials dict
:param credential_name: credential name
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.create_custom_model_credential(
model_type=ModelType.value_of(model_type),
model=model,
credentials=credentials,
credential_name=credential_name,
)
def update_model_credential(
self,
tenant_id: str,
provider: str,
model_type: str,
model: str,
credentials: dict,
credential_id: str,
credential_name: str | None,
) -> None:
"""
update model credentials.
:param tenant_id: workspace id
:param provider: provider name
:param model_type: model type
:param model: model name
:param credentials: model credentials dict
:param credential_id: credential id
:param credential_name: credential name
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.update_custom_model_credential(
model_type=ModelType.value_of(model_type),
model=model,
credentials=credentials,
credential_id=credential_id,
credential_name=credential_name,
)
def remove_model_credential(self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str):
"""
remove model credentials.
:param tenant_id: workspace id
:param provider: provider name
:param model_type: model type
:param model: model name
:param credential_id: credential id
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.delete_custom_model_credential(
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
)
def switch_active_custom_model_credential(
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
):
"""
switch model credentials.
:param tenant_id: workspace id
:param provider: provider name
:param model_type: model type
:param model: model name
:param credential_id: credential id
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.switch_custom_model_credential(
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
)
def add_model_credential_to_model_list(
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
):
"""
add model credentials to model list.
:param tenant_id: workspace id
:param provider: provider name
:param model_type: model type
:param model: model name
:param credential_id: credential id
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.add_model_credential_to_model(
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
)
def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str):
"""
        remove a custom model.
:param tenant_id: workspace id
:param provider: provider name
:param model_type: model type
:param model: model name
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.delete_custom_model(model_type=ModelType.value_of(model_type), model=model)
def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]:
"""
get models by model type.
:param tenant_id: workspace id
:param model_type: model type
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider available models
models = provider_configurations.get_models(model_type=ModelType.value_of(model_type), only_active=True)
# Group models by provider
provider_models: dict[str, list[ModelWithProviderEntity]] = {}
for model in models:
if model.provider.provider not in provider_models:
provider_models[model.provider.provider] = []
if model.deprecated:
continue
provider_models[model.provider.provider].append(model)
# convert to ProviderWithModelsResponse list
providers_with_models: list[ProviderWithModelsResponse] = []
for provider, models in provider_models.items():
if not models:
continue
first_model = models[0]
providers_with_models.append(
ProviderWithModelsResponse(
tenant_id=tenant_id,
provider=provider,
label=first_model.provider.label,
icon_small=first_model.provider.icon_small,
icon_large=first_model.provider.icon_large,
status=CustomConfigurationStatus.ACTIVE,
models=[
ProviderModelWithStatusEntity(
model=model.model,
label=model.label,
model_type=model.model_type,
features=model.features,
fetch_from=model.fetch_from,
model_properties=model.model_properties,
status=model.status,
load_balancing_enabled=model.load_balancing_enabled,
)
for model in models
],
)
)
return providers_with_models
def get_model_parameter_rules(self, tenant_id: str, provider: str, model: str) -> list[ParameterRule]:
"""
get model parameter rules.
Only supports LLM.
:param tenant_id: workspace id
:param provider: provider name
:param model: model name
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
# fetch credentials
credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model)
if not credentials:
return []
model_schema = provider_configuration.get_model_schema(
model_type=ModelType.LLM, model=model, credentials=credentials
)
return model_schema.parameter_rules if model_schema else []
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> DefaultModelResponse | None:
"""
get default model of model type.
:param tenant_id: workspace id
:param model_type: model type
:return:
"""
model_type_enum = ModelType.value_of(model_type)
try:
result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum)
return (
DefaultModelResponse(
model=result.model,
model_type=result.model_type,
provider=SimpleProviderEntityResponse(
tenant_id=tenant_id,
provider=result.provider.provider,
label=result.provider.label,
icon_small=result.provider.icon_small,
icon_large=result.provider.icon_large,
supported_model_types=result.provider.supported_model_types,
),
)
if result
else None
)
except Exception as e:
logger.debug("get_default_model_of_model_type error: %s", e)
return None
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str):
"""
update default model of model type.
:param tenant_id: workspace id
:param model_type: model type
:param provider: provider name
:param model: model name
:return:
"""
model_type_enum = ModelType.value_of(model_type)
self.provider_manager.update_default_model_record(
tenant_id=tenant_id, model_type=model_type_enum, provider=provider, model=model
)
def get_model_provider_icon(
self, tenant_id: str, provider: str, icon_type: str, lang: str
) -> tuple[bytes | None, str | None]:
"""
get model provider icon.
:param tenant_id: workspace id
:param provider: provider name
:param icon_type: icon type (icon_small or icon_large)
:param lang: language (zh_Hans or en_US)
:return:
"""
model_provider_factory = ModelProviderFactory(tenant_id)
byte_data, mime_type = model_provider_factory.get_provider_icon(provider, icon_type, lang)
return byte_data, mime_type
def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str):
"""
switch preferred provider.
:param tenant_id: workspace id
:param provider: provider name
:param preferred_provider_type: preferred provider type
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
# Convert preferred_provider_type to ProviderType
preferred_provider_type_enum = ProviderType.value_of(preferred_provider_type)
# Switch preferred provider type
provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum)
def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str):
"""
enable model.
:param tenant_id: workspace id
:param provider: provider name
:param model: model name
:param model_type: model type
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type))
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str):
"""
disable model.
:param tenant_id: workspace id
:param provider: provider name
:param model: model name
:param model_type: model type
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type))
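# Usage sketch (illustrative only, not part of the service): `service` stands for an
# instance of the model provider service defined above, and the tenant/provider/model
# values are assumed placeholders.
#
#     service.validate_model_credentials(tenant_id, "openai", "llm", "gpt-4o", {"api_key": "sk-..."})
#     service.create_model_credential(tenant_id, "openai", "llm", "gpt-4o", {"api_key": "sk-..."}, "primary")
#     service.update_default_model_of_model_type(tenant_id, "llm", "openai", "gpt-4o")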

View File

@@ -0,0 +1,94 @@
import enum
import uuid
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import Account
from models.model import OAuthProviderApp
from services.account_service import AccountService
class OAuthGrantType(enum.StrEnum):
AUTHORIZATION_CODE = "authorization_code"
REFRESH_TOKEN = "refresh_token"
OAUTH_AUTHORIZATION_CODE_REDIS_KEY = "oauth_provider:{client_id}:authorization_code:{code}"
OAUTH_ACCESS_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:access_token:{token}"
OAUTH_ACCESS_TOKEN_EXPIRES_IN = 60 * 60 * 12 # 12 hours
OAUTH_REFRESH_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:refresh_token:{token}"
OAUTH_REFRESH_TOKEN_EXPIRES_IN = 60 * 60 * 24 * 30 # 30 days
class OAuthServerService:
@staticmethod
def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None:
query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id)
with Session(db.engine) as session:
return session.execute(query).scalar_one_or_none()
@staticmethod
def sign_oauth_authorization_code(client_id: str, user_account_id: str) -> str:
code = str(uuid.uuid4())
redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code)
redis_client.set(redis_key, user_account_id, ex=60 * 10) # 10 minutes
return code
@staticmethod
def sign_oauth_access_token(
grant_type: OAuthGrantType,
code: str = "",
client_id: str = "",
refresh_token: str = "",
) -> tuple[str, str]:
match grant_type:
case OAuthGrantType.AUTHORIZATION_CODE:
redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code)
user_account_id = redis_client.get(redis_key)
if not user_account_id:
raise BadRequest("invalid code")
# delete code
redis_client.delete(redis_key)
access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id)
refresh_token = OAuthServerService._sign_oauth_refresh_token(client_id, user_account_id)
return access_token, refresh_token
case OAuthGrantType.REFRESH_TOKEN:
redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=refresh_token)
user_account_id = redis_client.get(redis_key)
if not user_account_id:
raise BadRequest("invalid refresh token")
access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id)
return access_token, refresh_token
@staticmethod
def _sign_oauth_access_token(client_id: str, user_account_id: str) -> str:
token = str(uuid.uuid4())
redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
redis_client.set(redis_key, user_account_id, ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN)
return token
@staticmethod
def _sign_oauth_refresh_token(client_id: str, user_account_id: str) -> str:
token = str(uuid.uuid4())
redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
redis_client.set(redis_key, user_account_id, ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN)
return token
@staticmethod
def validate_oauth_access_token(client_id: str, token: str) -> Account | None:
redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
user_account_id = redis_client.get(redis_key)
if not user_account_id:
return None
user_id_str = user_account_id.decode("utf-8")
return AccountService.load_user(user_id_str)
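# Token flow sketch (illustrative only; `client_id` and `account_id` are assumed inputs):
#
#     code = OAuthServerService.sign_oauth_authorization_code(client_id, account_id)
#     access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
#         OAuthGrantType.AUTHORIZATION_CODE, code=code, client_id=client_id
#     )
#     account = OAuthServerService.validate_oauth_access_token(client_id, access_token)
#     # later: exchange the refresh token for a fresh access token
#     access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
#         OAuthGrantType.REFRESH_TOKEN, client_id=client_id, refresh_token=refresh_token
#     )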

View File

@@ -0,0 +1,29 @@
import os
import httpx
class OperationService:
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
@classmethod
def _send_request(cls, method, endpoint, json=None, params=None):
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
url = f"{cls.base_url}{endpoint}"
response = httpx.request(method, url, json=json, params=params, headers=headers)
return response.json()
@classmethod
def record_utm(cls, tenant_id: str, utm_info: dict):
params = {
"tenant_id": tenant_id,
"utm_source": utm_info.get("utm_source", ""),
"utm_medium": utm_info.get("utm_medium", ""),
"utm_campaign": utm_info.get("utm_campaign", ""),
"utm_content": utm_info.get("utm_content", ""),
"utm_term": utm_info.get("utm_term", ""),
}
return cls._send_request("POST", "/tenant_utms", params=params)
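# Usage sketch (illustrative only; the UTM values are assumed placeholders):
#
#     OperationService.record_utm(tenant_id, {"utm_source": "newsletter", "utm_medium": "email"})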

View File

@@ -0,0 +1,276 @@
from typing import Any
from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
from extensions.ext_database import db
from models.model import App, TraceAppConfig
class OpsService:
@classmethod
def get_tracing_app_config(cls, app_id: str, tracing_provider: str):
"""
Get tracing app config
:param app_id: app id
:param tracing_provider: tracing provider
:return:
"""
trace_config_data: TraceAppConfig | None = (
db.session.query(TraceAppConfig)
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
)
if not trace_config_data:
return None
# decrypt_token and obfuscated_token
app = db.session.query(App).where(App.id == app_id).first()
if not app:
return None
tenant_id = app.tenant_id
if trace_config_data.tracing_config is None:
raise ValueError("Tracing config cannot be None.")
decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(
tenant_id, tracing_provider, trace_config_data.tracing_config
)
new_decrypt_tracing_config = OpsTraceManager.obfuscated_decrypt_token(tracing_provider, decrypt_tracing_config)
if tracing_provider == "arize" and (
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
):
try:
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
new_decrypt_tracing_config.update({"project_url": project_url})
except Exception:
new_decrypt_tracing_config.update({"project_url": "https://app.arize.com/"})
if tracing_provider == "phoenix" and (
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
):
try:
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
new_decrypt_tracing_config.update({"project_url": project_url})
except Exception:
new_decrypt_tracing_config.update({"project_url": "https://app.phoenix.arize.com/projects/"})
if tracing_provider == "langfuse" and (
"project_key" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_key")
):
try:
project_key = OpsTraceManager.get_trace_config_project_key(decrypt_tracing_config, tracing_provider)
new_decrypt_tracing_config.update(
{
"project_url": "{host}/project/{key}".format(
host=decrypt_tracing_config.get("host"), key=project_key
)
}
)
except Exception:
new_decrypt_tracing_config.update({"project_url": f"{decrypt_tracing_config.get('host')}/"})
if tracing_provider == "langsmith" and (
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
):
try:
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
new_decrypt_tracing_config.update({"project_url": project_url})
except Exception:
new_decrypt_tracing_config.update({"project_url": "https://smith.langchain.com/"})
if tracing_provider == "opik" and (
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
):
try:
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
new_decrypt_tracing_config.update({"project_url": project_url})
except Exception:
new_decrypt_tracing_config.update({"project_url": "https://www.comet.com/opik/"})
if tracing_provider == "weave" and (
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
):
try:
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
new_decrypt_tracing_config.update({"project_url": project_url})
except Exception:
new_decrypt_tracing_config.update({"project_url": "https://wandb.ai/"})
if tracing_provider == "aliyun" and (
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
):
try:
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
new_decrypt_tracing_config.update({"project_url": project_url})
except Exception:
new_decrypt_tracing_config.update({"project_url": "https://arms.console.aliyun.com/"})
if tracing_provider == "tencent" and (
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
):
try:
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
new_decrypt_tracing_config.update({"project_url": project_url})
except Exception:
new_decrypt_tracing_config.update({"project_url": "https://console.cloud.tencent.com/apm"})
if tracing_provider == "mlflow" and (
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
):
try:
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
new_decrypt_tracing_config.update({"project_url": project_url})
except Exception:
new_decrypt_tracing_config.update({"project_url": "http://localhost:5000/"})
if tracing_provider == "databricks" and (
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
):
try:
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
new_decrypt_tracing_config.update({"project_url": project_url})
except Exception:
new_decrypt_tracing_config.update({"project_url": "https://www.databricks.com/"})
trace_config_data.tracing_config = new_decrypt_tracing_config
return trace_config_data.to_dict()
@classmethod
def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict):
"""
Create tracing app config
:param app_id: app id
:param tracing_provider: tracing provider
:param tracing_config: tracing config
:return:
"""
try:
provider_config_map[tracing_provider]
except KeyError:
return {"error": f"Invalid tracing provider: {tracing_provider}"}
provider_config: dict[str, Any] = provider_config_map[tracing_provider]
config_class: type[BaseTracingConfig] = provider_config["config_class"]
other_keys: list[str] = provider_config["other_keys"]
default_config_instance = config_class.model_validate(tracing_config)
for key in other_keys:
if key in tracing_config and tracing_config[key] == "":
tracing_config[key] = getattr(default_config_instance, key, None)
# api check
if not OpsTraceManager.check_trace_config_is_effective(tracing_config, tracing_provider):
return {"error": "Invalid Credentials"}
# get project url
if tracing_provider in ("arize", "phoenix"):
try:
project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
except Exception:
project_url = None
elif tracing_provider == "langfuse":
try:
project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider)
project_url = f"{tracing_config.get('host')}/project/{project_key}"
except Exception:
project_url = None
elif tracing_provider in ("langsmith", "opik", "mlflow", "databricks", "tencent"):
try:
project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
except Exception:
project_url = None
else:
project_url = None
# check if trace config already exists
trace_config_data: TraceAppConfig | None = (
db.session.query(TraceAppConfig)
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
)
if trace_config_data:
return None
# get tenant id
app = db.session.query(App).where(App.id == app_id).first()
if not app:
return None
tenant_id = app.tenant_id
tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config)
if project_url:
tracing_config["project_url"] = project_url
trace_config_data = TraceAppConfig(
app_id=app_id,
tracing_provider=tracing_provider,
tracing_config=tracing_config,
)
db.session.add(trace_config_data)
db.session.commit()
return {"result": "success"}
@classmethod
def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict):
"""
Update tracing app config
:param app_id: app id
:param tracing_provider: tracing provider
:param tracing_config: tracing config
:return:
"""
try:
provider_config_map[tracing_provider]
except KeyError:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
# check if trace config already exists
current_trace_config = (
db.session.query(TraceAppConfig)
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
)
if not current_trace_config:
return None
# get tenant id
app = db.session.query(App).where(App.id == app_id).first()
if not app:
return None
tenant_id = app.tenant_id
tracing_config = OpsTraceManager.encrypt_tracing_config(
tenant_id, tracing_provider, tracing_config, current_trace_config.tracing_config
)
# api check
# decrypt_token
decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(tenant_id, tracing_provider, tracing_config)
if not OpsTraceManager.check_trace_config_is_effective(decrypt_tracing_config, tracing_provider):
raise ValueError("Invalid Credentials")
current_trace_config.tracing_config = tracing_config
db.session.commit()
return current_trace_config.to_dict()
@classmethod
def delete_tracing_app_config(cls, app_id: str, tracing_provider: str):
"""
Delete tracing app config
:param app_id: app id
:param tracing_provider: tracing provider
:return:
"""
trace_config = (
db.session.query(TraceAppConfig)
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
)
if not trace_config:
return None
db.session.delete(trace_config)
db.session.commit()
return True
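# Usage sketch (illustrative only; the Langfuse keys shown are assumed example fields,
# not a complete or authoritative schema):
#
#     OpsService.create_tracing_app_config(
#         app_id, "langfuse", {"public_key": "pk-...", "secret_key": "sk-...", "host": "https://cloud.langfuse.com"}
#     )
#     config = OpsService.get_tracing_app_config(app_id, "langfuse")
#     OpsService.delete_tracing_app_config(app_id, "langfuse")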

View File

View File

@@ -0,0 +1,212 @@
import json
import logging
import click
import sqlalchemy as sa
from extensions.ext_database import db
from models.provider_ids import GenericProviderID, ModelProviderID, ToolProviderID
logger = logging.getLogger(__name__)
class PluginDataMigration:
@classmethod
def migrate(cls):
cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table
cls.migrate_db_records("provider_models", "provider_name", ModelProviderID)
cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID)
cls.migrate_db_records("tenant_default_models", "provider_name", ModelProviderID)
cls.migrate_db_records("tenant_preferred_model_providers", "provider_name", ModelProviderID)
cls.migrate_db_records("provider_model_settings", "provider_name", ModelProviderID)
cls.migrate_db_records("load_balancing_model_configs", "provider_name", ModelProviderID)
cls.migrate_datasets()
cls.migrate_db_records("embeddings", "provider_name", ModelProviderID) # large table
cls.migrate_db_records("dataset_collection_bindings", "provider_name", ModelProviderID)
cls.migrate_db_records("tool_builtin_providers", "provider", ToolProviderID)
@classmethod
def migrate_datasets(cls):
table_name = "datasets"
provider_column_name = "embedding_model_provider"
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
processed_count = 0
failed_ids = []
while True:
sql = f"""select id, {provider_column_name} as provider_name, retrieval_model from {table_name}
where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
limit 1000"""
with db.engine.begin() as conn:
rs = conn.execute(sa.text(sql))
current_iter_count = 0
for i in rs:
record_id = str(i.id)
provider_name = str(i.provider_name)
retrieval_model = i.retrieval_model
logger.debug(
"Processing dataset %s with retrieval model of type %s",
record_id,
type(retrieval_model),
)
if record_id in failed_ids:
continue
retrieval_model_changed = False
if retrieval_model:
if (
"reranking_model" in retrieval_model
and "reranking_provider_name" in retrieval_model["reranking_model"]
and retrieval_model["reranking_model"]["reranking_provider_name"]
and "/" not in retrieval_model["reranking_model"]["reranking_provider_name"]
):
click.echo(
click.style(
f"[{processed_count}] Migrating {table_name} {record_id} "
f"(reranking_provider_name: "
f"{retrieval_model['reranking_model']['reranking_provider_name']})",
fg="white",
)
)
# update google to langgenius/gemini/google etc.
retrieval_model["reranking_model"]["reranking_provider_name"] = ModelProviderID(
retrieval_model["reranking_model"]["reranking_provider_name"]
).to_string()
retrieval_model_changed = True
click.echo(
click.style(
f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
fg="white",
)
)
try:
# update provider name append with "langgenius/{provider_name}/{provider_name}"
params = {"record_id": record_id}
update_retrieval_model_sql = ""
if retrieval_model and retrieval_model_changed:
update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
params["retrieval_model"] = json.dumps(retrieval_model)
params["provider_name"] = ModelProviderID(provider_name).to_string()
sql = f"""update {table_name}
set {provider_column_name} =
:provider_name
{update_retrieval_model_sql}
where id = :record_id"""
conn.execute(sa.text(sql), params)
click.echo(
click.style(
f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
fg="green",
)
)
except Exception:
failed_ids.append(record_id)
click.echo(
click.style(
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
fg="red",
)
)
logger.exception(
"[%s] Failed to migrate [%s] %s (%s)", processed_count, table_name, record_id, provider_name
)
continue
current_iter_count += 1
processed_count += 1
if not current_iter_count:
break
click.echo(
click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
)
@classmethod
def migrate_db_records(cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]):
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
processed_count = 0
failed_ids = []
last_id = "00000000-0000-0000-0000-000000000000"
while True:
sql = f"""
SELECT id, {provider_column_name} AS provider_name
FROM {table_name}
WHERE {provider_column_name} NOT LIKE '%/%'
AND {provider_column_name} IS NOT NULL
AND {provider_column_name} != ''
AND id > :last_id
ORDER BY id ASC
LIMIT 5000
"""
params = {"last_id": last_id or ""}
with db.engine.begin() as conn:
rs = conn.execute(sa.text(sql), params)
current_iter_count = 0
batch_updates = []
for i in rs:
current_iter_count += 1
processed_count += 1
record_id = str(i.id)
last_id = record_id
provider_name = str(i.provider_name)
if record_id in failed_ids:
continue
click.echo(
click.style(
f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
fg="white",
)
)
try:
# update jina to langgenius/jina_tool/jina etc.
updated_value = provider_cls(provider_name).to_string()
batch_updates.append((updated_value, record_id))
except Exception:
failed_ids.append(record_id)
click.echo(
click.style(
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
fg="red",
)
)
logger.exception(
"[%s] Failed to migrate [%s] %s (%s)", processed_count, table_name, record_id, provider_name
)
continue
if batch_updates:
update_sql = f"""
UPDATE {table_name}
SET {provider_column_name} = :updated_value
WHERE id = :record_id
"""
conn.execute(sa.text(update_sql), [{"updated_value": u, "record_id": r} for u, r in batch_updates])
click.echo(
click.style(
f"[{processed_count}] Batch migrated [{len(batch_updates)}] records from [{table_name}]",
fg="green",
)
)
if not current_iter_count:
break
click.echo(
click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
)
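# Usage sketch (illustrative only): running the full migration rewrites legacy provider
# names (e.g. "google") into plugin-scoped ids (e.g. "langgenius/gemini/google") across
# the tables listed in `migrate`.
#
#     PluginDataMigration.migrate()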

View File

@@ -0,0 +1,137 @@
import re
from configs import dify_config
from core.helper import marketplace
from core.plugin.entities.plugin import PluginDependency, PluginInstallationSource
from core.plugin.impl.plugin import PluginInstaller
from models.provider_ids import ModelProviderID, ToolProviderID
# Compile regex pattern for version extraction at module level for better performance
_VERSION_REGEX = re.compile(r":(?P<version>[0-9]+(?:\.[0-9]+){2}(?:[+-][0-9A-Za-z.-]+)?)(?:@|$)")
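# Example (illustrative only): for an identifier shaped like "langgenius/foo:1.2.3@abcdef",
# the pattern above captures "1.2.3" as the version.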
class DependenciesAnalysisService:
@classmethod
def analyze_tool_dependency(cls, tool_id: str) -> str:
"""
Analyze the dependency of a tool.
Convert the tool id to the plugin_id
"""
        return ToolProviderID(tool_id).plugin_id
@classmethod
def analyze_model_provider_dependency(cls, model_provider_id: str) -> str:
"""
Analyze the dependency of a model provider.
Convert the model provider id to the plugin_id
"""
        return ModelProviderID(model_provider_id).plugin_id
@classmethod
def get_leaked_dependencies(cls, tenant_id: str, dependencies: list[PluginDependency]) -> list[PluginDependency]:
"""
        Check dependencies and return those that are missing (leaked) from the current workspace.
"""
required_plugin_unique_identifiers = []
for dependency in dependencies:
required_plugin_unique_identifiers.append(dependency.value.plugin_unique_identifier)
manager = PluginInstaller()
# get leaked dependencies
missing_plugins = manager.fetch_missing_dependencies(tenant_id, required_plugin_unique_identifiers)
missing_plugin_unique_identifiers = {plugin.plugin_unique_identifier: plugin for plugin in missing_plugins}
leaked_dependencies = []
for dependency in dependencies:
unique_identifier = dependency.value.plugin_unique_identifier
if unique_identifier in missing_plugin_unique_identifiers:
# Extract version for Marketplace dependencies
if dependency.type == PluginDependency.Type.Marketplace:
version_match = _VERSION_REGEX.search(unique_identifier)
if version_match:
dependency.value.version = version_match.group("version")
# Create and append the dependency (same for all types)
leaked_dependencies.append(
PluginDependency(
type=dependency.type,
value=dependency.value,
current_identifier=missing_plugin_unique_identifiers[unique_identifier].current_identifier,
)
)
return leaked_dependencies
@classmethod
def generate_dependencies(cls, tenant_id: str, dependencies: list[str]) -> list[PluginDependency]:
"""
Generate dependencies through the list of plugin ids
"""
dependencies = list(set(dependencies))
manager = PluginInstaller()
plugins = manager.fetch_plugin_installation_by_ids(tenant_id, dependencies)
result = []
for plugin in plugins:
if plugin.source == PluginInstallationSource.Github:
result.append(
PluginDependency(
type=PluginDependency.Type.Github,
value=PluginDependency.Github(
repo=plugin.meta["repo"],
version=plugin.meta["version"],
package=plugin.meta["package"],
github_plugin_unique_identifier=plugin.plugin_unique_identifier,
),
)
)
elif plugin.source == PluginInstallationSource.Marketplace:
result.append(
PluginDependency(
type=PluginDependency.Type.Marketplace,
value=PluginDependency.Marketplace(
marketplace_plugin_unique_identifier=plugin.plugin_unique_identifier
),
)
)
elif plugin.source == PluginInstallationSource.Package:
result.append(
PluginDependency(
type=PluginDependency.Type.Package,
value=PluginDependency.Package(plugin_unique_identifier=plugin.plugin_unique_identifier),
)
)
elif plugin.source == PluginInstallationSource.Remote:
raise ValueError(
f"You used a remote plugin: {plugin.plugin_unique_identifier} in the app, please remove it first"
" if you want to export the DSL."
)
else:
raise ValueError(f"Unknown plugin source: {plugin.source}")
return result
@classmethod
def generate_latest_dependencies(cls, dependencies: list[str]) -> list[PluginDependency]:
"""
Generate the latest version of dependencies
"""
dependencies = list(set(dependencies))
if not dify_config.MARKETPLACE_ENABLED:
return []
deps = marketplace.batch_fetch_plugin_manifests(dependencies)
return [
PluginDependency(
type=PluginDependency.Type.Marketplace,
value=PluginDependency.Marketplace(marketplace_plugin_unique_identifier=dep.latest_package_identifier),
)
for dep in deps
]
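# Usage sketch (illustrative only; `dsl_dependencies` stands for dependencies already
# parsed from an app DSL, which is assumed context here):
#
#     missing = DependenciesAnalysisService.get_leaked_dependencies(tenant_id, dsl_dependencies)
#     for dep in missing:
#         print(dep.type, dep.value.plugin_unique_identifier)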

View File

@@ -0,0 +1,66 @@
from core.plugin.impl.endpoint import PluginEndpointClient
class EndpointService:
@classmethod
def create_endpoint(cls, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict):
return PluginEndpointClient().create_endpoint(
tenant_id=tenant_id,
user_id=user_id,
plugin_unique_identifier=plugin_unique_identifier,
name=name,
settings=settings,
)
@classmethod
def list_endpoints(cls, tenant_id: str, user_id: str, page: int, page_size: int):
return PluginEndpointClient().list_endpoints(
tenant_id=tenant_id,
user_id=user_id,
page=page,
page_size=page_size,
)
@classmethod
def list_endpoints_for_single_plugin(cls, tenant_id: str, user_id: str, plugin_id: str, page: int, page_size: int):
return PluginEndpointClient().list_endpoints_for_single_plugin(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=plugin_id,
page=page,
page_size=page_size,
)
@classmethod
def update_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict):
return PluginEndpointClient().update_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint_id=endpoint_id,
name=name,
settings=settings,
)
@classmethod
def delete_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str):
return PluginEndpointClient().delete_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint_id=endpoint_id,
)
@classmethod
def enable_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str):
return PluginEndpointClient().enable_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint_id=endpoint_id,
)
@classmethod
def disable_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str):
return PluginEndpointClient().disable_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint_id=endpoint_id,
)

View File

@@ -0,0 +1,65 @@
import json
import uuid
from core.plugin.impl.base import BasePluginClient
from extensions.ext_redis import redis_client
class OAuthProxyService(BasePluginClient):
# Default max age for proxy context parameter in seconds
__MAX_AGE__ = 5 * 60 # 5 minutes
__KEY_PREFIX__ = "oauth_proxy_context:"
@staticmethod
def create_proxy_context(
user_id: str,
tenant_id: str,
plugin_id: str,
provider: str,
        extra_data: dict | None = None,
credential_id: str | None = None,
):
"""
Create a proxy context for an OAuth 2.0 authorization request.
        The returned context id acts as a state-style nonce and is a security measure against
        Cross-Site Request Forgery (CSRF) attacks: a unique value is generated and stored in a
        distributed cache (Redis) together with the user's session context. Include it as the
        'proxy_context' parameter in the authorization URL; on callback, `use_proxy_context`
        verifies and consumes it, ensuring the request's integrity and authenticity and
        mitigating replay attacks.
"""
context_id = str(uuid.uuid4())
data = {
            **(extra_data or {}),
"user_id": user_id,
"plugin_id": plugin_id,
"tenant_id": tenant_id,
"provider": provider,
}
if credential_id:
data["credential_id"] = credential_id
redis_client.setex(
f"{OAuthProxyService.__KEY_PREFIX__}{context_id}",
OAuthProxyService.__MAX_AGE__,
json.dumps(data),
)
return context_id
@staticmethod
def use_proxy_context(context_id: str):
"""
Validate the proxy context parameter.
This checks if the context_id is valid and not expired.
"""
if not context_id:
raise ValueError("context_id is required")
# get data from redis
key = f"{OAuthProxyService.__KEY_PREFIX__}{context_id}"
data = redis_client.get(key)
if not data:
raise ValueError("context_id is invalid")
redis_client.delete(key)
return json.loads(data)
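# Flow sketch (illustrative only): the context id travels as the state-style value in the
# authorization URL and is redeemed exactly once on callback.
#
#     context_id = OAuthProxyService.create_proxy_context(user_id, tenant_id, plugin_id, provider)
#     # ... redirect the user to the provider's authorize URL carrying context_id ...
#     context = OAuthProxyService.use_proxy_context(context_id)  # raises ValueError if unknown or expired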

View File

@@ -0,0 +1,87 @@
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.account import TenantPluginAutoUpgradeStrategy
class PluginAutoUpgradeService:
@staticmethod
def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None:
with Session(db.engine) as session:
return (
session.query(TenantPluginAutoUpgradeStrategy)
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.first()
)
@staticmethod
def change_strategy(
tenant_id: str,
strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting,
upgrade_time_of_day: int,
upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode,
exclude_plugins: list[str],
include_plugins: list[str],
) -> bool:
with Session(db.engine) as session:
exist_strategy = (
session.query(TenantPluginAutoUpgradeStrategy)
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.first()
)
if not exist_strategy:
strategy = TenantPluginAutoUpgradeStrategy(
tenant_id=tenant_id,
strategy_setting=strategy_setting,
upgrade_time_of_day=upgrade_time_of_day,
upgrade_mode=upgrade_mode,
exclude_plugins=exclude_plugins,
include_plugins=include_plugins,
)
session.add(strategy)
else:
exist_strategy.strategy_setting = strategy_setting
exist_strategy.upgrade_time_of_day = upgrade_time_of_day
exist_strategy.upgrade_mode = upgrade_mode
exist_strategy.exclude_plugins = exclude_plugins
exist_strategy.include_plugins = include_plugins
session.commit()
return True
@staticmethod
def exclude_plugin(tenant_id: str, plugin_id: str) -> bool:
with Session(db.engine) as session:
exist_strategy = (
session.query(TenantPluginAutoUpgradeStrategy)
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.first()
)
if not exist_strategy:
# create for this tenant
PluginAutoUpgradeService.change_strategy(
tenant_id,
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
0,
TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
[plugin_id],
[],
)
return True
else:
if exist_strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE:
if plugin_id not in exist_strategy.exclude_plugins:
new_exclude_plugins = exist_strategy.exclude_plugins.copy()
new_exclude_plugins.append(plugin_id)
exist_strategy.exclude_plugins = new_exclude_plugins
elif exist_strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL:
if plugin_id in exist_strategy.include_plugins:
new_include_plugins = exist_strategy.include_plugins.copy()
new_include_plugins.remove(plugin_id)
exist_strategy.include_plugins = new_include_plugins
elif exist_strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL:
exist_strategy.upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
exist_strategy.exclude_plugins = [plugin_id]
session.commit()
return True
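# Usage sketch (illustrative only; the plugin ids and the upgrade hour value are assumed examples):
#
#     PluginAutoUpgradeService.change_strategy(
#         tenant_id,
#         TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
#         0,
#         TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
#         ["langgenius/foo"],
#         [],
#     )
#     PluginAutoUpgradeService.exclude_plugin(tenant_id, "langgenius/bar")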

View File

@@ -0,0 +1,595 @@
import datetime
import json
import logging
import time
from collections.abc import Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any
from uuid import uuid4
import click
import sqlalchemy as sa
import tqdm
from flask import Flask, current_app
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.helper import marketplace
from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus
from core.plugin.impl.plugin import PluginInstaller
from core.tools.entities.tool_entities import ToolProviderType
from extensions.ext_database import db
from models.account import Tenant
from models.model import App, AppMode, AppModelConfig
from models.provider_ids import ModelProviderID, ToolProviderID
from models.tools import BuiltinToolProvider
from models.workflow import Workflow
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__)
excluded_providers = ["time", "audio", "code", "webscraper"]
class PluginMigration:
@classmethod
def extract_plugins(cls, filepath: str, workers: int):
"""
        Extract the installed plugin ids of every tenant and append them to `filepath`, one JSON line per tenant.
"""
from threading import Lock
click.echo(click.style("Migrating models/tools to new plugin Mechanism", fg="white"))
ended_at = datetime.datetime.now()
started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
current_time = started_at
with Session(db.engine) as session:
total_tenant_count = session.query(Tenant.id).count()
click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
handled_tenant_count = 0
file_lock = Lock()
counter_lock = Lock()
thread_pool = ThreadPoolExecutor(max_workers=workers)
def process_tenant(flask_app: Flask, tenant_id: str):
with flask_app.app_context():
nonlocal handled_tenant_count
try:
plugins = cls.extract_installed_plugin_ids(tenant_id)
# Use lock when writing to file
with file_lock:
with open(filepath, "a") as f:
f.write(json.dumps({"tenant_id": tenant_id, "plugins": plugins}) + "\n")
                    # Use lock when updating counter
                    with counter_lock:
                        handled_tenant_count += 1
click.echo(
click.style(
f"[{datetime.datetime.now()}] "
f"Processed {handled_tenant_count} tenants "
f"({(handled_tenant_count / total_tenant_count) * 100:.1f}%), "
f"{handled_tenant_count}/{total_tenant_count}",
fg="green",
)
)
except Exception:
logger.exception("Failed to process tenant %s", tenant_id)
futures = []
while current_time < ended_at:
click.echo(click.style(f"Current time: {current_time}, Started at: {datetime.datetime.now()}", fg="white"))
# Initial interval of 1 day, will be dynamically adjusted based on tenant count
interval = datetime.timedelta(days=1)
# Process tenants in this batch
with Session(db.engine) as session:
# Calculate tenant count in next batch with current interval
# Try different intervals until we find one with a reasonable tenant count
test_intervals = [
datetime.timedelta(days=1),
datetime.timedelta(hours=12),
datetime.timedelta(hours=6),
datetime.timedelta(hours=3),
datetime.timedelta(hours=1),
]
tenant_count = 0
for test_interval in test_intervals:
tenant_count = (
session.query(Tenant.id)
.where(Tenant.created_at.between(current_time, current_time + test_interval))
.count()
)
if tenant_count <= 100:
interval = test_interval
break
else:
# If all intervals have too many tenants, use minimum interval
interval = datetime.timedelta(hours=1)
# Adjust interval to target ~100 tenants per batch
if tenant_count > 0:
# Scale interval based on ratio to target count
interval = min(
datetime.timedelta(days=1), # Max 1 day
max(
datetime.timedelta(hours=1), # Min 1 hour
interval * (100 / tenant_count), # Scale to target 100
),
)
batch_end = min(current_time + interval, ended_at)
rs = (
session.query(Tenant.id)
.where(Tenant.created_at.between(current_time, batch_end))
.order_by(Tenant.created_at)
)
tenants = []
for row in rs:
tenant_id = str(row.id)
try:
tenants.append(tenant_id)
except Exception:
logger.exception("Failed to process tenant %s", tenant_id)
continue
futures.append(
thread_pool.submit(
process_tenant,
current_app._get_current_object(), # type: ignore
tenant_id,
)
)
current_time = batch_end
# wait for all threads to finish
for future in futures:
future.result()
@classmethod
def extract_installed_plugin_ids(cls, tenant_id: str) -> Sequence[str]:
"""
Extract installed plugin ids.
"""
tools = cls.extract_tool_tables(tenant_id)
models = cls.extract_model_tables(tenant_id)
workflows = cls.extract_workflow_tables(tenant_id)
apps = cls.extract_app_tables(tenant_id)
return list({*tools, *models, *workflows, *apps})
@classmethod
def extract_model_tables(cls, tenant_id: str) -> Sequence[str]:
"""
Extract model tables.
"""
models: list[str] = []
table_pairs = [
("providers", "provider_name"),
("provider_models", "provider_name"),
("provider_orders", "provider_name"),
("tenant_default_models", "provider_name"),
("tenant_preferred_model_providers", "provider_name"),
("provider_model_settings", "provider_name"),
("load_balancing_model_configs", "provider_name"),
]
for table, column in table_pairs:
models.extend(cls.extract_model_table(tenant_id, table, column))
# duplicate models
models = list(set(models))
return models
@classmethod
def extract_model_table(cls, tenant_id: str, table: str, column: str) -> Sequence[str]:
"""
Extract model table.
"""
with Session(db.engine) as session:
rs = session.execute(
sa.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id}
)
result = []
for row in rs:
provider_name = str(row[0])
result.append(ModelProviderID(provider_name).plugin_id)
return result
@classmethod
def extract_tool_tables(cls, tenant_id: str) -> Sequence[str]:
"""
Extract tool tables.
"""
with Session(db.engine) as session:
rs = session.query(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant_id).all()
result = []
for row in rs:
result.append(ToolProviderID(row.provider).plugin_id)
return result
@classmethod
def extract_workflow_tables(cls, tenant_id: str) -> Sequence[str]:
"""
Extract workflow tables, only ToolNode is required.
"""
with Session(db.engine) as session:
rs = session.query(Workflow).where(Workflow.tenant_id == tenant_id).all()
result = []
for row in rs:
graph = row.graph_dict
# get nodes
nodes = graph.get("nodes", [])
for node in nodes:
data = node.get("data", {})
if data.get("type") == "tool":
provider_name = data.get("provider_name")
provider_type = data.get("provider_type")
if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN:
result.append(ToolProviderID(provider_name).plugin_id)
return result
@classmethod
def extract_app_tables(cls, tenant_id: str) -> Sequence[str]:
"""
Extract app tables.
"""
with Session(db.engine) as session:
apps = session.query(App).where(App.tenant_id == tenant_id).all()
if not apps:
return []
agent_app_model_config_ids = [
app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT
]
rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
result = []
for row in rs:
agent_config = row.agent_mode_dict
if "tools" in agent_config and isinstance(agent_config["tools"], list):
for tool in agent_config["tools"]:
if isinstance(tool, dict):
try:
tool_entity = AgentToolEntity.model_validate(tool)
if (
tool_entity.provider_type == ToolProviderType.BUILT_IN
and tool_entity.provider_id not in excluded_providers
):
result.append(ToolProviderID(tool_entity.provider_id).plugin_id)
except Exception:
logger.exception("Failed to process tool %s", tool)
continue
return result
@classmethod
def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> str | None:
"""
Fetch plugin unique identifier using plugin id.
"""
plugin_manifest = marketplace.batch_fetch_plugin_manifests([plugin_id])
if not plugin_manifest:
return None
return plugin_manifest[0].latest_package_identifier
@classmethod
def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str):
"""
Extract unique plugins.
"""
Path(output_file).write_text(json.dumps(cls.extract_unique_plugins(extracted_plugins)))
@classmethod
def extract_unique_plugins(cls, extracted_plugins: str) -> Mapping[str, Any]:
plugins: dict[str, str] = {}
plugin_ids = []
plugin_not_exist = []
logger.info("Extracting unique plugins from %s", extracted_plugins)
with open(extracted_plugins) as f:
for line in f:
data = json.loads(line)
new_plugin_ids = data.get("plugins", [])
for plugin_id in new_plugin_ids:
if plugin_id not in plugin_ids:
plugin_ids.append(plugin_id)
def fetch_plugin(plugin_id):
try:
unique_identifier = cls._fetch_plugin_unique_identifier(plugin_id)
if unique_identifier:
plugins[plugin_id] = unique_identifier
else:
plugin_not_exist.append(plugin_id)
except Exception:
logger.exception("Failed to fetch plugin unique identifier for %s", plugin_id)
plugin_not_exist.append(plugin_id)
with ThreadPoolExecutor(max_workers=10) as executor:
list(tqdm.tqdm(executor.map(fetch_plugin, plugin_ids), total=len(plugin_ids)))
return {"plugins": plugins, "plugin_not_exist": plugin_not_exist}
@classmethod
def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100):
"""
Install plugins.
"""
manager = PluginInstaller()
plugins = cls.extract_unique_plugins(extracted_plugins)
not_installed = []
plugin_install_failed = []
# use a fake tenant id to install all the plugins
fake_tenant_id = uuid4().hex
logger.info("Installing %s plugin instances for fake tenant %s", len(plugins["plugins"]), fake_tenant_id)
thread_pool = ThreadPoolExecutor(max_workers=workers)
response = cls.handle_plugin_instance_install(fake_tenant_id, plugins["plugins"])
if response.get("failed"):
plugin_install_failed.extend(response.get("failed", []))
def install(tenant_id: str, plugin_ids: list[str]):
logger.info("Installing %s plugins for tenant %s", len(plugin_ids), tenant_id)
# fetch plugin already installed
installed_plugins = manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
# at most 64 plugins one batch
for i in range(0, len(plugin_ids), 64):
batch_plugin_ids = plugin_ids[i : i + 64]
batch_plugin_identifiers = [
plugins["plugins"][plugin_id]
for plugin_id in batch_plugin_ids
if plugin_id not in installed_plugins_ids and plugin_id in plugins["plugins"]
]
manager.install_from_identifiers(
tenant_id,
batch_plugin_identifiers,
PluginInstallationSource.Marketplace,
metas=[
{
"plugin_unique_identifier": identifier,
}
for identifier in batch_plugin_identifiers
],
)
with open(extracted_plugins) as f:
"""
Read line by line, and install plugins for each tenant.
"""
for line in f:
data = json.loads(line)
tenant_id = data.get("tenant_id")
plugin_ids = data.get("plugins", [])
current_not_installed = {
"tenant_id": tenant_id,
"plugin_not_exist": [],
}
                # record plugin ids that could not be resolved to a marketplace identifier
                for plugin_id in plugin_ids:
                    unique_identifier = plugins["plugins"].get(plugin_id)
                    if not unique_identifier:
                        current_not_installed["plugin_not_exist"].append(plugin_id)
if current_not_installed["plugin_not_exist"]:
not_installed.append(current_not_installed)
thread_pool.submit(install, tenant_id, plugin_ids)
thread_pool.shutdown(wait=True)
logger.info("Uninstall plugins")
# get installation
try:
installation = manager.list_plugins(fake_tenant_id)
while installation:
for plugin in installation:
manager.uninstall(fake_tenant_id, plugin.installation_id)
installation = manager.list_plugins(fake_tenant_id)
except Exception:
logger.exception("Failed to get installation for tenant %s", fake_tenant_id)
Path(output_file).write_text(
json.dumps(
{
"not_installed": not_installed,
"plugin_install_failed": plugin_install_failed,
}
)
)
@classmethod
def install_rag_pipeline_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100) -> None:
"""
Install rag pipeline plugins.
"""
manager = PluginInstaller()
plugins = cls.extract_unique_plugins(extracted_plugins)
plugin_install_failed = []
# use a fake tenant id to install all the plugins
fake_tenant_id = uuid4().hex
logger.info("Installing %s plugin instances for fake tenant %s", len(plugins["plugins"]), fake_tenant_id)
thread_pool = ThreadPoolExecutor(max_workers=workers)
response = cls.handle_plugin_instance_install(fake_tenant_id, plugins["plugins"])
if response.get("failed"):
plugin_install_failed.extend(response.get("failed", []))
        def install(tenant_id: str, plugin_ids: dict[str, str]) -> None:
            # update the enclosing counters so the results survive after each worker returns
            nonlocal total_success_tenant, total_failed_tenant
            logger.info("Installing %s plugins for tenant %s", len(plugin_ids), tenant_id)
try:
# fetch plugin already installed
installed_plugins = manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
# at most 64 plugins one batch
for i in range(0, len(plugin_ids), 64):
batch_plugin_ids = list(plugin_ids.keys())[i : i + 64]
batch_plugin_identifiers = [
plugin_ids[plugin_id]
for plugin_id in batch_plugin_ids
if plugin_id not in installed_plugins_ids and plugin_id in plugin_ids
]
PluginService.install_from_marketplace_pkg(tenant_id, batch_plugin_identifiers)
total_success_tenant += 1
except Exception:
logger.exception("Failed to install plugins for tenant %s", tenant_id)
total_failed_tenant += 1
page = 1
total_success_tenant = 0
total_failed_tenant = 0
while True:
# paginate
tenants = db.paginate(sa.select(Tenant).order_by(Tenant.created_at.desc()), page=page, per_page=100)
if tenants.items is None or len(tenants.items) == 0:
break
for tenant in tenants:
tenant_id = tenant.id
                # install this tenant's plugins in a worker thread
                thread_pool.submit(
                    install,
                    tenant_id,
                    plugins.get("plugins", {}),
                )
page += 1
thread_pool.shutdown(wait=True)
# uninstall all the plugins for fake tenant
try:
installation = manager.list_plugins(fake_tenant_id)
while installation:
for plugin in installation:
manager.uninstall(fake_tenant_id, plugin.installation_id)
installation = manager.list_plugins(fake_tenant_id)
except Exception:
logger.exception("Failed to get installation for tenant %s", fake_tenant_id)
Path(output_file).write_text(
json.dumps(
{
"total_success_tenant": total_success_tenant,
"total_failed_tenant": total_failed_tenant,
"plugin_install_failed": plugin_install_failed,
}
)
)
@classmethod
def handle_plugin_instance_install(
cls, tenant_id: str, plugin_identifiers_map: Mapping[str, str]
) -> Mapping[str, Any]:
"""
Install plugins for a tenant.
"""
manager = PluginInstaller()
# download all the plugins and upload
thread_pool = ThreadPoolExecutor(max_workers=10)
futures = []
for plugin_id, plugin_identifier in plugin_identifiers_map.items():
def download_and_upload(tenant_id, plugin_id, plugin_identifier):
plugin_package = marketplace.download_plugin_pkg(plugin_identifier)
if not plugin_package:
raise Exception(f"Failed to download plugin {plugin_identifier}")
# upload
manager.upload_pkg(tenant_id, plugin_package, verify_signature=True)
futures.append(thread_pool.submit(download_and_upload, tenant_id, plugin_id, plugin_identifier))
# Wait for all downloads to complete
for future in futures:
future.result() # This will raise any exceptions that occurred
thread_pool.shutdown(wait=True)
success = []
failed = []
reverse_map = {v: k for k, v in plugin_identifiers_map.items()}
# at most 8 plugins one batch
for i in range(0, len(plugin_identifiers_map), 8):
batch_plugin_ids = list(plugin_identifiers_map.keys())[i : i + 8]
batch_plugin_identifiers = [plugin_identifiers_map[plugin_id] for plugin_id in batch_plugin_ids]
try:
response = manager.install_from_identifiers(
tenant_id=tenant_id,
identifiers=batch_plugin_identifiers,
source=PluginInstallationSource.Marketplace,
metas=[
{
"plugin_unique_identifier": identifier,
}
for identifier in batch_plugin_identifiers
],
)
except Exception:
# add to failed
failed.extend(batch_plugin_identifiers)
continue
if response.all_installed:
success.extend(batch_plugin_identifiers)
continue
task_id = response.task_id
done = False
while not done:
status = manager.fetch_plugin_installation_task(tenant_id, task_id)
if status.status in [PluginInstallTaskStatus.Failed, PluginInstallTaskStatus.Success]:
for plugin in status.plugins:
if plugin.status == PluginInstallTaskStatus.Success:
success.append(reverse_map[plugin.plugin_unique_identifier])
else:
failed.append(reverse_map[plugin.plugin_unique_identifier])
logger.error(
"Failed to install plugin %s, error: %s",
plugin.plugin_unique_identifier,
plugin.message,
)
done = True
else:
time.sleep(1)
return {"success": success, "failed": failed}

View File

@@ -0,0 +1,107 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from sqlalchemy.orm import Session
from core.plugin.entities.parameters import PluginParameterOption
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.dynamic_select import DynamicSelectClient
from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import create_tool_provider_encrypter
from core.trigger.entities.api_entities import TriggerProviderSubscriptionApiEntity
from core.trigger.entities.entities import SubscriptionBuilder
from extensions.ext_database import db
from models.tools import BuiltinToolProvider
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
class PluginParameterService:
@staticmethod
def get_dynamic_select_options(
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
action: str,
parameter: str,
credential_id: str | None,
provider_type: Literal["tool", "trigger"],
) -> Sequence[PluginParameterOption]:
"""
Get dynamic select options for a plugin parameter.
Args:
tenant_id: The tenant ID.
plugin_id: The plugin ID.
provider: The provider name.
action: The action name.
parameter: The parameter name.
"""
credentials: Mapping[str, Any] = {}
credential_type: str = CredentialType.UNAUTHORIZED.value
match provider_type:
case "tool":
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
# init tool configuration
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
controller=provider_controller,
)
# check if credentials are required
if not provider_controller.need_credentials:
credentials = {}
else:
# fetch credentials from db
with Session(db.engine) as session:
if credential_id:
db_record = (
session.query(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
BuiltinToolProvider.id == credential_id,
)
.first()
)
else:
db_record = (
session.query(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)
if db_record is None:
raise ValueError(f"Builtin provider {provider} not found when fetching credentials")
credentials = encrypter.decrypt(db_record.credentials)
credential_type = db_record.credential_type
case "trigger":
subscription: TriggerProviderSubscriptionApiEntity | SubscriptionBuilder | None
if credential_id:
subscription = TriggerSubscriptionBuilderService.get_subscription_builder(credential_id)
if not subscription:
trigger_subscription = TriggerProviderService.get_subscription_by_id(tenant_id, credential_id)
subscription = trigger_subscription.to_api_entity() if trigger_subscription else None
else:
trigger_subscription = TriggerProviderService.get_subscription_by_id(tenant_id)
subscription = trigger_subscription.to_api_entity() if trigger_subscription else None
if subscription is None:
raise ValueError(f"Subscription {credential_id} not found")
credentials = subscription.credentials
credential_type = subscription.credential_type or CredentialType.UNAUTHORIZED
return (
DynamicSelectClient()
.fetch_dynamic_select_options(
tenant_id, user_id, plugin_id, provider, action, credentials, credential_type, parameter
)
.options
)
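# Usage sketch (illustrative only; the provider, action and parameter names are assumed examples):
#
#     options = PluginParameterService.get_dynamic_select_options(
#         tenant_id, user_id, "langgenius/foo", "foo", "search", "index_name",
#         credential_id=None, provider_type="tool",
#     )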

View File

@@ -0,0 +1,34 @@
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.account import TenantPluginPermission
class PluginPermissionService:
@staticmethod
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
with Session(db.engine) as session:
return session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
@staticmethod
def change_permission(
tenant_id: str,
install_permission: TenantPluginPermission.InstallPermission,
debug_permission: TenantPluginPermission.DebugPermission,
):
with Session(db.engine) as session:
permission = (
session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
)
if not permission:
permission = TenantPluginPermission(
tenant_id=tenant_id, install_permission=install_permission, debug_permission=debug_permission
)
session.add(permission)
else:
permission.install_permission = install_permission
permission.debug_permission = debug_permission
session.commit()
return True
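# --- Hedged usage sketch (illustrative only, not part of the source file) ---
# Assuming valid enum members of TenantPluginPermission.InstallPermission and
# TenantPluginPermission.DebugPermission (member names depend on the model),
# the permission record for a tenant could be read and updated like this:
#
#   current = PluginPermissionService.get_permission(tenant_id="tenant-uuid")  # hypothetical ID
#   PluginPermissionService.change_permission(
#       tenant_id="tenant-uuid",
#       install_permission=install_permission,  # a TenantPluginPermission.InstallPermission member
#       debug_permission=debug_permission,      # a TenantPluginPermission.DebugPermission member
#   )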

View File

@@ -0,0 +1,525 @@
import logging
from collections.abc import Mapping, Sequence
from mimetypes import guess_type
from pydantic import BaseModel
from yarl import URL
from configs import dify_config
from core.helper import marketplace
from core.helper.download import download_with_size_limit
from core.helper.marketplace import download_plugin_pkg
from core.plugin.entities.bundle import PluginBundleDependency
from core.plugin.entities.plugin import (
PluginDeclaration,
PluginEntity,
PluginInstallation,
PluginInstallationSource,
)
from core.plugin.entities.plugin_daemon import (
PluginDecodeResponse,
PluginInstallTask,
PluginListResponse,
PluginVerification,
)
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.debugging import PluginDebuggingClient
from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_redis import redis_client
from models.provider_ids import GenericProviderID
from services.errors.plugin import PluginInstallationForbiddenError
from services.feature_service import FeatureService, PluginInstallationScope
logger = logging.getLogger(__name__)
class PluginService:
class LatestPluginCache(BaseModel):
plugin_id: str
version: str
unique_identifier: str
status: str
deprecated_reason: str
alternative_plugin_id: str
REDIS_KEY_PREFIX = "plugin_service:latest_plugin:"
REDIS_TTL = 60 * 5 # 5 minutes
@staticmethod
def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
"""
Fetch the latest plugin version
"""
result: dict[str, PluginService.LatestPluginCache | None] = {}
try:
cache_not_exists = []
# Try to get from Redis first
for plugin_id in plugin_ids:
cached_data = redis_client.get(f"{PluginService.REDIS_KEY_PREFIX}{plugin_id}")
if cached_data:
result[plugin_id] = PluginService.LatestPluginCache.model_validate_json(cached_data)
else:
cache_not_exists.append(plugin_id)
if cache_not_exists:
manifests = {
manifest.plugin_id: manifest
for manifest in marketplace.batch_fetch_plugin_manifests(cache_not_exists)
}
for plugin_id, manifest in manifests.items():
latest_plugin = PluginService.LatestPluginCache(
plugin_id=plugin_id,
version=manifest.latest_version,
unique_identifier=manifest.latest_package_identifier,
status=manifest.status,
deprecated_reason=manifest.deprecated_reason,
alternative_plugin_id=manifest.alternative_plugin_id,
)
# Store in Redis
redis_client.setex(
f"{PluginService.REDIS_KEY_PREFIX}{plugin_id}",
PluginService.REDIS_TTL,
latest_plugin.model_dump_json(),
)
result[plugin_id] = latest_plugin
# pop plugin_id from cache_not_exists
cache_not_exists.remove(plugin_id)
for plugin_id in cache_not_exists:
result[plugin_id] = None
return result
except Exception:
logger.exception("failed to fetch latest plugin version")
return result
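    # --- Illustrative note (not part of the source file) ---
    # fetch_latest_plugin_version() consults Redis first: each plugin id is cached
    # under "plugin_service:latest_plugin:<plugin_id>" for REDIS_TTL (5 minutes) as
    # the JSON dump of LatestPluginCache. Only cache misses trigger a batched
    # marketplace lookup, e.g. (hypothetical ids):
    #
    #   latest = PluginService.fetch_latest_plugin_version(["vendor/plugin_a", "vendor/plugin_b"])
    #   info = latest.get("vendor/plugin_a")  # LatestPluginCache, or None if unknown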
@staticmethod
def _check_marketplace_only_permission():
"""
Check if the marketplace only permission is enabled
"""
features = FeatureService.get_system_features()
if features.plugin_installation_permission.restrict_to_marketplace_only:
raise PluginInstallationForbiddenError("Plugin installation is restricted to marketplace only")
@staticmethod
def _check_plugin_installation_scope(plugin_verification: PluginVerification | None):
"""
Check the plugin installation scope
"""
features = FeatureService.get_system_features()
match features.plugin_installation_permission.plugin_installation_scope:
case PluginInstallationScope.OFFICIAL_ONLY:
if (
plugin_verification is None
or plugin_verification.authorized_category != PluginVerification.AuthorizedCategory.Langgenius
):
raise PluginInstallationForbiddenError("Plugin installation is restricted to official only")
case PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS:
if plugin_verification is None or plugin_verification.authorized_category not in [
PluginVerification.AuthorizedCategory.Langgenius,
PluginVerification.AuthorizedCategory.Partner,
]:
raise PluginInstallationForbiddenError(
"Plugin installation is restricted to official and specific partners"
)
case PluginInstallationScope.NONE:
raise PluginInstallationForbiddenError("Installing plugins is not allowed")
case PluginInstallationScope.ALL:
pass
@staticmethod
def get_debugging_key(tenant_id: str) -> str:
"""
get the debugging key of the tenant
"""
manager = PluginDebuggingClient()
return manager.get_debugging_key(tenant_id)
@staticmethod
def list_latest_versions(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
"""
List the latest versions of the plugins
"""
return PluginService.fetch_latest_plugin_version(plugin_ids)
@staticmethod
def list(tenant_id: str) -> list[PluginEntity]:
"""
list all plugins of the tenant
"""
manager = PluginInstaller()
plugins = manager.list_plugins(tenant_id)
return plugins
@staticmethod
def list_with_total(tenant_id: str, page: int, page_size: int) -> PluginListResponse:
"""
list all plugins of the tenant
"""
manager = PluginInstaller()
plugins = manager.list_plugins_with_total(tenant_id, page, page_size)
return plugins
@staticmethod
def list_installations_from_ids(tenant_id: str, ids: Sequence[str]) -> Sequence[PluginInstallation]:
"""
List plugin installations from ids
"""
manager = PluginInstaller()
return manager.fetch_plugin_installation_by_ids(tenant_id, ids)
@classmethod
def get_plugin_icon_url(cls, tenant_id: str, filename: str) -> str:
url_prefix = (
URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "plugin" / "icon"
)
return str(url_prefix % {"tenant_id": tenant_id, "filename": filename})
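    # --- Illustrative note (not part of the source file) ---
    # yarl's `%` operator appends query parameters, so with a hypothetical
    # CONSOLE_API_URL = "https://console.example.com" the result would be:
    #   https://console.example.com/console/api/workspaces/current/plugin/icon?tenant_id=<tenant_id>&filename=<filename>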
@staticmethod
def get_asset(tenant_id: str, asset_file: str) -> tuple[bytes, str]:
"""
get the asset file of the plugin
"""
manager = PluginAssetManager()
# guess mime type
mime_type, _ = guess_type(asset_file)
return manager.fetch_asset(tenant_id, asset_file), mime_type or "application/octet-stream"
@staticmethod
def extract_asset(tenant_id: str, plugin_unique_identifier: str, file_name: str) -> bytes:
manager = PluginAssetManager()
return manager.extract_asset(tenant_id, plugin_unique_identifier, file_name)
@staticmethod
def check_plugin_unique_identifier(tenant_id: str, plugin_unique_identifier: str) -> bool:
"""
check if the plugin unique identifier is already installed by other tenant
"""
manager = PluginInstaller()
return manager.fetch_plugin_by_identifier(tenant_id, plugin_unique_identifier)
@staticmethod
def fetch_plugin_manifest(tenant_id: str, plugin_unique_identifier: str) -> PluginDeclaration:
"""
Fetch plugin manifest
"""
manager = PluginInstaller()
return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
@staticmethod
def is_plugin_verified(tenant_id: str, plugin_unique_identifier: str) -> bool:
"""
Check if the plugin is verified
"""
manager = PluginInstaller()
try:
return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier).verified
except Exception:
return False
@staticmethod
def fetch_install_tasks(tenant_id: str, page: int, page_size: int) -> Sequence[PluginInstallTask]:
"""
Fetch plugin installation tasks
"""
manager = PluginInstaller()
return manager.fetch_plugin_installation_tasks(tenant_id, page, page_size)
@staticmethod
def fetch_install_task(tenant_id: str, task_id: str) -> PluginInstallTask:
manager = PluginInstaller()
return manager.fetch_plugin_installation_task(tenant_id, task_id)
@staticmethod
def delete_install_task(tenant_id: str, task_id: str) -> bool:
"""
Delete a plugin installation task
"""
manager = PluginInstaller()
return manager.delete_plugin_installation_task(tenant_id, task_id)
@staticmethod
def delete_all_install_task_items(
tenant_id: str,
) -> bool:
"""
Delete all plugin installation task items
"""
manager = PluginInstaller()
return manager.delete_all_plugin_installation_task_items(tenant_id)
@staticmethod
def delete_install_task_item(tenant_id: str, task_id: str, identifier: str) -> bool:
"""
Delete a plugin installation task item
"""
manager = PluginInstaller()
return manager.delete_plugin_installation_task_item(tenant_id, task_id, identifier)
@staticmethod
def upgrade_plugin_with_marketplace(
tenant_id: str, original_plugin_unique_identifier: str, new_plugin_unique_identifier: str
):
"""
Upgrade plugin with marketplace
"""
if not dify_config.MARKETPLACE_ENABLED:
raise ValueError("marketplace is not enabled")
if original_plugin_unique_identifier == new_plugin_unique_identifier:
raise ValueError("you should not upgrade plugin with the same plugin")
# check if plugin pkg is already downloaded
manager = PluginInstaller()
features = FeatureService.get_system_features()
try:
manager.fetch_plugin_manifest(tenant_id, new_plugin_unique_identifier)
# already downloaded, skip, and record install event
marketplace.record_install_plugin_event(new_plugin_unique_identifier)
except Exception:
# plugin not installed, download and upload pkg
pkg = download_plugin_pkg(new_plugin_unique_identifier)
response = manager.upload_pkg(
tenant_id,
pkg,
verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only,
)
# check if the plugin is available to install
PluginService._check_plugin_installation_scope(response.verification)
return manager.upgrade_plugin(
tenant_id,
original_plugin_unique_identifier,
new_plugin_unique_identifier,
PluginInstallationSource.Marketplace,
{
"plugin_unique_identifier": new_plugin_unique_identifier,
},
)
@staticmethod
def upgrade_plugin_with_github(
tenant_id: str,
original_plugin_unique_identifier: str,
new_plugin_unique_identifier: str,
repo: str,
version: str,
package: str,
):
"""
Upgrade plugin with github
"""
PluginService._check_marketplace_only_permission()
manager = PluginInstaller()
return manager.upgrade_plugin(
tenant_id,
original_plugin_unique_identifier,
new_plugin_unique_identifier,
PluginInstallationSource.Github,
{
"repo": repo,
"version": version,
"package": package,
},
)
@staticmethod
def upload_pkg(tenant_id: str, pkg: bytes, verify_signature: bool = False) -> PluginDecodeResponse:
"""
        Upload plugin package files,
        returns a PluginDecodeResponse containing the plugin_unique_identifier
"""
PluginService._check_marketplace_only_permission()
manager = PluginInstaller()
features = FeatureService.get_system_features()
response = manager.upload_pkg(
tenant_id,
pkg,
verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only,
)
PluginService._check_plugin_installation_scope(response.verification)
return response
@staticmethod
def upload_pkg_from_github(
tenant_id: str, repo: str, version: str, package: str, verify_signature: bool = False
) -> PluginDecodeResponse:
"""
        Upload a plugin package from a GitHub release,
        returns a PluginDecodeResponse containing the plugin_unique_identifier
"""
PluginService._check_marketplace_only_permission()
pkg = download_with_size_limit(
f"https://github.com/{repo}/releases/download/{version}/{package}", dify_config.PLUGIN_MAX_PACKAGE_SIZE
)
features = FeatureService.get_system_features()
manager = PluginInstaller()
response = manager.upload_pkg(
tenant_id,
pkg,
verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only,
)
PluginService._check_plugin_installation_scope(response.verification)
return response
@staticmethod
def upload_bundle(
tenant_id: str, bundle: bytes, verify_signature: bool = False
) -> Sequence[PluginBundleDependency]:
"""
Upload a plugin bundle and return the dependencies.
"""
manager = PluginInstaller()
PluginService._check_marketplace_only_permission()
return manager.upload_bundle(tenant_id, bundle, verify_signature)
@staticmethod
def install_from_local_pkg(tenant_id: str, plugin_unique_identifiers: Sequence[str]):
PluginService._check_marketplace_only_permission()
manager = PluginInstaller()
for plugin_unique_identifier in plugin_unique_identifiers:
resp = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier)
PluginService._check_plugin_installation_scope(resp.verification)
return manager.install_from_identifiers(
tenant_id,
plugin_unique_identifiers,
PluginInstallationSource.Package,
[{}],
)
@staticmethod
def install_from_github(tenant_id: str, plugin_unique_identifier: str, repo: str, version: str, package: str):
"""
        Install a plugin from a GitHub release package,
        returns the installation response from the plugin daemon
"""
PluginService._check_marketplace_only_permission()
manager = PluginInstaller()
plugin_decode_response = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier)
PluginService._check_plugin_installation_scope(plugin_decode_response.verification)
return manager.install_from_identifiers(
tenant_id,
[plugin_unique_identifier],
PluginInstallationSource.Github,
[
{
"repo": repo,
"version": version,
"package": package,
}
],
)
@staticmethod
def fetch_marketplace_pkg(tenant_id: str, plugin_unique_identifier: str) -> PluginDeclaration:
"""
Fetch marketplace package
"""
if not dify_config.MARKETPLACE_ENABLED:
raise ValueError("marketplace is not enabled")
features = FeatureService.get_system_features()
manager = PluginInstaller()
try:
declaration = manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
except Exception:
pkg = download_plugin_pkg(plugin_unique_identifier)
response = manager.upload_pkg(
tenant_id,
pkg,
verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only,
)
# check if the plugin is available to install
PluginService._check_plugin_installation_scope(response.verification)
declaration = response.manifest
return declaration
@staticmethod
def install_from_marketplace_pkg(tenant_id: str, plugin_unique_identifiers: Sequence[str]):
"""
Install plugin from marketplace package files,
returns installation task id
"""
if not dify_config.MARKETPLACE_ENABLED:
raise ValueError("marketplace is not enabled")
manager = PluginInstaller()
# collect actual plugin_unique_identifiers
actual_plugin_unique_identifiers = []
metas = []
features = FeatureService.get_system_features()
# check if already downloaded
for plugin_unique_identifier in plugin_unique_identifiers:
try:
manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
plugin_decode_response = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier)
# check if the plugin is available to install
PluginService._check_plugin_installation_scope(plugin_decode_response.verification)
# already downloaded, skip
actual_plugin_unique_identifiers.append(plugin_unique_identifier)
metas.append({"plugin_unique_identifier": plugin_unique_identifier})
except Exception:
# plugin not installed, download and upload pkg
pkg = download_plugin_pkg(plugin_unique_identifier)
response = manager.upload_pkg(
tenant_id,
pkg,
verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only,
)
# check if the plugin is available to install
PluginService._check_plugin_installation_scope(response.verification)
# use response plugin_unique_identifier
actual_plugin_unique_identifiers.append(response.unique_identifier)
metas.append({"plugin_unique_identifier": response.unique_identifier})
return manager.install_from_identifiers(
tenant_id,
actual_plugin_unique_identifiers,
PluginInstallationSource.Marketplace,
metas,
)
@staticmethod
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
manager = PluginInstaller()
return manager.uninstall(tenant_id, plugin_installation_id)
@staticmethod
def check_tools_existence(tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]:
"""
Check if the tools exist
"""
manager = PluginInstaller()
return manager.check_tools_existence(tenant_id, provider_ids)
@staticmethod
def fetch_plugin_readme(tenant_id: str, plugin_unique_identifier: str, language: str) -> str:
"""
Fetch plugin readme
"""
manager = PluginInstaller()
return manager.fetch_plugin_readme(tenant_id, plugin_unique_identifier, language)
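# --- Hedged usage sketch (illustrative only, not part of the source file) ---
# A typical marketplace installation, assuming MARKETPLACE_ENABLED is true and
# using hypothetical identifiers:
#
#   response = PluginService.install_from_marketplace_pkg(
#       tenant_id="tenant-uuid",
#       plugin_unique_identifiers=["vendor/example_plugin:0.0.1@<checksum>"],
#   )
#   # The daemon response carries an installation task; its progress can then be
#   # polled via PluginService.fetch_install_task(tenant_id, task_id).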

View File

@@ -0,0 +1,22 @@
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel
class DatasourceNodeRunApiEntity(BaseModel):
pipeline_id: str
node_id: str
inputs: dict[str, Any]
datasource_type: str
credential_id: str | None = None
is_published: bool
class PipelineRunApiEntity(BaseModel):
inputs: Mapping[str, Any]
datasource_type: str
datasource_info_list: list[Mapping[str, Any]]
start_node_id: str
is_published: bool
response_mode: str
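# --- Hedged example (illustrative only, not part of the source file) ---
# A request body matching PipelineRunApiEntity might look like this (all values hypothetical):
#
#   PipelineRunApiEntity(
#       inputs={"chunk_size": 512},
#       datasource_type="online_document",
#       datasource_info_list=[{"url": "https://example.com/doc"}],
#       start_node_id="datasource-node-1",
#       is_published=True,
#       response_mode="streaming",
#   )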

View File

@@ -0,0 +1,115 @@
from collections.abc import Mapping
from typing import Any, Union
from configs import dify_config
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from models.dataset import Document, Pipeline
from models.model import Account, App, EndUser
from models.workflow import Workflow
from services.rag_pipeline.rag_pipeline import RagPipelineService
class PipelineGenerateService:
@classmethod
def generate(
cls,
pipeline: Pipeline,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
):
"""
Pipeline Content Generate
:param pipeline: pipeline
:param user: user
:param args: args
:param invoke_from: invoke from
:param streaming: streaming
:return:
"""
try:
workflow = cls._get_workflow(pipeline, invoke_from)
if original_document_id := args.get("original_document_id"):
# update document status to waiting
cls.update_document_status(original_document_id)
return PipelineGenerator.convert_to_event_stream(
PipelineGenerator().generate(
pipeline=pipeline,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=streaming,
call_depth=0,
workflow_thread_pool_id=None,
),
)
except Exception:
raise
@staticmethod
def _get_max_active_requests(app_model: App) -> int:
max_active_requests = app_model.max_active_requests
if max_active_requests is None:
max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS)
return max_active_requests
@classmethod
def generate_single_iteration(
cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True
):
workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER)
return PipelineGenerator.convert_to_event_stream(
PipelineGenerator().single_iteration_generate(
pipeline=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)
@classmethod
def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True):
workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER)
return PipelineGenerator.convert_to_event_stream(
PipelineGenerator().single_loop_generate(
pipeline=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)
@classmethod
def _get_workflow(cls, pipeline: Pipeline, invoke_from: InvokeFrom) -> Workflow:
"""
Get workflow
:param pipeline: pipeline
:param invoke_from: invoke from
:return:
"""
rag_pipeline_service = RagPipelineService()
if invoke_from == InvokeFrom.DEBUGGER:
# fetch draft workflow by app_model
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
if not workflow:
raise ValueError("Workflow not initialized")
else:
# fetch published workflow by app_model
workflow = rag_pipeline_service.get_published_workflow(pipeline=pipeline)
if not workflow:
raise ValueError("Workflow not published")
return workflow
@classmethod
def update_document_status(cls, document_id: str):
"""
Update document status to waiting
:param document_id: document id
"""
document = db.session.query(Document).where(Document.id == document_id).first()
if document:
document.indexing_status = "waiting"
db.session.add(document)
db.session.commit()
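# --- Hedged usage sketch (illustrative only, not part of the source file) ---
# Assuming a loaded Pipeline and an Account (or EndUser), the generator returns
# an event stream suitable for an SSE response:
#
#   stream = PipelineGenerateService.generate(
#       pipeline=pipeline,                 # a models.dataset.Pipeline instance
#       user=current_account,              # hypothetical Account or EndUser
#       args=args,                         # request mapping; may include "original_document_id"
#       invoke_from=InvokeFrom.DEBUGGER,
#       streaming=True,
#   )
#   # InvokeFrom.DEBUGGER uses the draft workflow; other sources require a published workflow.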

View File

@@ -0,0 +1,63 @@
import json
from os import path
from pathlib import Path
from flask import current_app
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
    Retrieves pipeline templates from the built-in data at constants/pipeline_templates.json
"""
builtin_data: dict | None = None
def get_type(self) -> str:
return PipelineTemplateType.BUILTIN
def get_pipeline_templates(self, language: str) -> dict:
result = self.fetch_pipeline_templates_from_builtin(language)
return result
def get_pipeline_template_detail(self, template_id: str):
result = self.fetch_pipeline_template_detail_from_builtin(template_id)
return result
@classmethod
def _get_builtin_data(cls) -> dict:
"""
Get builtin data.
:return:
"""
if cls.builtin_data:
return cls.builtin_data
root_path = current_app.root_path
cls.builtin_data = json.loads(
Path(path.join(root_path, "constants", "pipeline_templates.json")).read_text(encoding="utf-8")
)
return cls.builtin_data or {}
@classmethod
def fetch_pipeline_templates_from_builtin(cls, language: str) -> dict:
"""
Fetch pipeline templates from builtin.
:param language: language
:return:
"""
builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
return builtin_data.get("pipeline_templates", {}).get(language, {})
@classmethod
def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> dict | None:
"""
Fetch pipeline template detail from builtin.
:param template_id: Template ID
:return:
"""
builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
return builtin_data.get("pipeline_templates", {}).get(template_id)
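# --- Hedged note (illustrative only, not part of the source file) ---
# Based on the accessors above, constants/pipeline_templates.json is expected to
# look roughly like this (keys hypothetical):
#
#   {
#     "pipeline_templates": {
#       "en-US": { "<template_id>": { "name": "...", "description": "...", ... } }
#     }
#   }
#
# Note that the list lookup is keyed by language while the detail lookup is keyed
# by template id directly under "pipeline_templates".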

View File

@@ -0,0 +1,80 @@
import yaml
from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models.dataset import PipelineCustomizedTemplate
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
    Retrieves customized pipeline templates from the database
"""
def get_pipeline_templates(self, language: str) -> dict:
_, current_tenant_id = current_account_with_tenant()
result = self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language)
return result
def get_pipeline_template_detail(self, template_id: str):
result = self.fetch_pipeline_template_detail_from_db(template_id)
return result
def get_type(self) -> str:
return PipelineTemplateType.CUSTOMIZED
@classmethod
def fetch_pipeline_templates_from_customized(cls, tenant_id: str, language: str) -> dict:
"""
Fetch pipeline templates from db.
:param tenant_id: tenant id
:param language: language
:return:
"""
pipeline_customized_templates = (
db.session.query(PipelineCustomizedTemplate)
.where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
.order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc())
.all()
)
recommended_pipelines_results = []
for pipeline_customized_template in pipeline_customized_templates:
recommended_pipeline_result = {
"id": pipeline_customized_template.id,
"name": pipeline_customized_template.name,
"description": pipeline_customized_template.description,
"icon": pipeline_customized_template.icon,
"position": pipeline_customized_template.position,
"chunk_structure": pipeline_customized_template.chunk_structure,
}
recommended_pipelines_results.append(recommended_pipeline_result)
return {"pipeline_templates": recommended_pipelines_results}
@classmethod
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None:
"""
Fetch pipeline template detail from db.
:param template_id: Template ID
:return:
"""
pipeline_template = (
db.session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
)
if not pipeline_template:
return None
dsl_data = yaml.safe_load(pipeline_template.yaml_content)
graph_data = dsl_data.get("workflow", {}).get("graph", {})
return {
"id": pipeline_template.id,
"name": pipeline_template.name,
"icon_info": pipeline_template.icon,
"description": pipeline_template.description,
"chunk_structure": pipeline_template.chunk_structure,
"export_data": pipeline_template.yaml_content,
"graph": graph_data,
"created_by": pipeline_template.created_user_name,
}

View File

@@ -0,0 +1,77 @@
import yaml
from extensions.ext_database import db
from models.dataset import PipelineBuiltInTemplate
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
    Retrieves built-in pipeline templates from the database
"""
def get_pipeline_templates(self, language: str) -> dict:
result = self.fetch_pipeline_templates_from_db(language)
return result
def get_pipeline_template_detail(self, template_id: str):
result = self.fetch_pipeline_template_detail_from_db(template_id)
return result
def get_type(self) -> str:
return PipelineTemplateType.DATABASE
@classmethod
def fetch_pipeline_templates_from_db(cls, language: str) -> dict:
"""
Fetch pipeline templates from db.
:param language: language
:return:
"""
pipeline_built_in_templates: list[PipelineBuiltInTemplate] = (
db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.language == language).all()
)
recommended_pipelines_results = []
for pipeline_built_in_template in pipeline_built_in_templates:
recommended_pipeline_result = {
"id": pipeline_built_in_template.id,
"name": pipeline_built_in_template.name,
"description": pipeline_built_in_template.description,
"icon": pipeline_built_in_template.icon,
"copyright": pipeline_built_in_template.copyright,
"privacy_policy": pipeline_built_in_template.privacy_policy,
"position": pipeline_built_in_template.position,
"chunk_structure": pipeline_built_in_template.chunk_structure,
}
recommended_pipelines_results.append(recommended_pipeline_result)
return {"pipeline_templates": recommended_pipelines_results}
@classmethod
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None:
"""
Fetch pipeline template detail from db.
        :param template_id: Template ID
:return:
"""
# is in public recommended list
pipeline_template = (
db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.id == template_id).first()
)
if not pipeline_template:
return None
dsl_data = yaml.safe_load(pipeline_template.yaml_content)
graph_data = dsl_data.get("workflow", {}).get("graph", {})
return {
"id": pipeline_template.id,
"name": pipeline_template.name,
"icon_info": pipeline_template.icon,
"description": pipeline_template.description,
"chunk_structure": pipeline_template.chunk_structure,
"export_data": pipeline_template.yaml_content,
"graph": graph_data,
}

View File

@@ -0,0 +1,17 @@
from abc import ABC, abstractmethod
class PipelineTemplateRetrievalBase(ABC):
"""Interface for pipeline template retrieval."""
@abstractmethod
def get_pipeline_templates(self, language: str) -> dict:
raise NotImplementedError
@abstractmethod
def get_pipeline_template_detail(self, template_id: str) -> dict | None:
raise NotImplementedError
@abstractmethod
def get_type(self) -> str:
raise NotImplementedError

View File

@@ -0,0 +1,26 @@
from services.rag_pipeline.pipeline_template.built_in.built_in_retrieval import BuiltInPipelineTemplateRetrieval
from services.rag_pipeline.pipeline_template.customized.customized_retrieval import CustomizedPipelineTemplateRetrieval
from services.rag_pipeline.pipeline_template.database.database_retrieval import DatabasePipelineTemplateRetrieval
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
from services.rag_pipeline.pipeline_template.remote.remote_retrieval import RemotePipelineTemplateRetrieval
class PipelineTemplateRetrievalFactory:
@staticmethod
def get_pipeline_template_factory(mode: str) -> type[PipelineTemplateRetrievalBase]:
match mode:
case PipelineTemplateType.REMOTE:
return RemotePipelineTemplateRetrieval
case PipelineTemplateType.CUSTOMIZED:
return CustomizedPipelineTemplateRetrieval
case PipelineTemplateType.DATABASE:
return DatabasePipelineTemplateRetrieval
case PipelineTemplateType.BUILTIN:
return BuiltInPipelineTemplateRetrieval
case _:
raise ValueError(f"invalid fetch recommended apps mode: {mode}")
@staticmethod
def get_built_in_pipeline_template_retrieval():
return BuiltInPipelineTemplateRetrieval
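# --- Hedged usage sketch (illustrative only, not part of the source file) ---
# Retrieval mode is selected by PipelineTemplateType, e.g.:
#
#   retrieval_cls = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(PipelineTemplateType.DATABASE)
#   templates = retrieval_cls().get_pipeline_templates(language="en-US")  # "en-US" is a hypothetical language code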

View File

@@ -0,0 +1,8 @@
from enum import StrEnum
class PipelineTemplateType(StrEnum):
REMOTE = "remote"
DATABASE = "database"
CUSTOMIZED = "customized"
BUILTIN = "builtin"

View File

@@ -0,0 +1,67 @@
import logging
import httpx
from configs import dify_config
from services.rag_pipeline.pipeline_template.database.database_retrieval import DatabasePipelineTemplateRetrieval
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
logger = logging.getLogger(__name__)
class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
    Retrieves pipeline templates from the Dify official endpoint, falling back to the database
"""
def get_pipeline_template_detail(self, template_id: str):
try:
result = self.fetch_pipeline_template_detail_from_dify_official(template_id)
except Exception as e:
logger.warning("fetch recommended app detail from dify official failed: %r, switch to database.", e)
result = DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id)
return result
def get_pipeline_templates(self, language: str) -> dict:
try:
result = self.fetch_pipeline_templates_from_dify_official(language)
except Exception as e:
logger.warning("fetch pipeline templates from dify official failed: %r, switch to database.", e)
result = DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language)
return result
def get_type(self) -> str:
return PipelineTemplateType.REMOTE
@classmethod
def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict | None:
"""
Fetch pipeline template detail from dify official.
        :param template_id: Template ID
:return:
"""
domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN
url = f"{domain}/pipeline-templates/{template_id}"
response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0))
if response.status_code != 200:
return None
data: dict = response.json()
return data
@classmethod
def fetch_pipeline_templates_from_dify_official(cls, language: str) -> dict:
"""
Fetch pipeline templates from dify official.
:param language: language
:return:
"""
domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN
url = f"{domain}/pipeline-templates?language={language}"
response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0))
if response.status_code != 200:
raise ValueError(f"fetch pipeline templates failed, status code: {response.status_code}")
result: dict = response.json()
return result

File diff suppressed because it is too large

View File

@@ -0,0 +1,945 @@
import base64
import hashlib
import json
import logging
import uuid
from collections.abc import Mapping
from datetime import UTC, datetime
from enum import StrEnum
from typing import cast
from urllib.parse import urlparse
from uuid import uuid4
import yaml # type: ignore
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from flask_login import current_user
from packaging import version
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.helper import ssrf_proxy
from core.helper.name_generator import generate_incremental_name
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import PluginDependency
from core.workflow.enums import NodeType
from core.workflow.nodes.datasource.entities import DatasourceNodeData
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from core.workflow.nodes.llm.entities import LLMNodeData
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
from core.workflow.nodes.tool.entities import ToolNodeData
from extensions.ext_redis import redis_client
from factories import variable_factory
from models import Account
from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
from models.workflow import Workflow, WorkflowType
from services.entities.knowledge_entities.rag_pipeline_entities import (
IconInfo,
KnowledgeConfiguration,
RagPipelineDatasetCreateEntity,
)
from services.plugin.dependencies_analysis import DependenciesAnalysisService
logger = logging.getLogger(__name__)
IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
CURRENT_DSL_VERSION = "0.1.0"
class ImportMode(StrEnum):
YAML_CONTENT = "yaml-content"
YAML_URL = "yaml-url"
class ImportStatus(StrEnum):
COMPLETED = "completed"
COMPLETED_WITH_WARNINGS = "completed-with-warnings"
PENDING = "pending"
FAILED = "failed"
class RagPipelineImportInfo(BaseModel):
id: str
status: ImportStatus
pipeline_id: str | None = None
current_dsl_version: str = CURRENT_DSL_VERSION
imported_dsl_version: str = ""
error: str = ""
dataset_id: str | None = None
class CheckDependenciesResult(BaseModel):
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
def _check_version_compatibility(imported_version: str) -> ImportStatus:
"""Determine import status based on version comparison"""
try:
current_ver = version.parse(CURRENT_DSL_VERSION)
imported_ver = version.parse(imported_version)
except version.InvalidVersion:
return ImportStatus.FAILED
# If imported version is newer than current, always return PENDING
if imported_ver > current_ver:
return ImportStatus.PENDING
# If imported version is older than current's major, return PENDING
if imported_ver.major < current_ver.major:
return ImportStatus.PENDING
# If imported version is older than current's minor, return COMPLETED_WITH_WARNINGS
if imported_ver.minor < current_ver.minor:
return ImportStatus.COMPLETED_WITH_WARNINGS
# If imported version equals or is older than current's micro, return COMPLETED
return ImportStatus.COMPLETED
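# --- Illustrative note (not part of the source file) ---
# With CURRENT_DSL_VERSION = "0.1.0", _check_version_compatibility() yields:
#   "0.2.0" -> PENDING                     (imported version is newer than current)
#   "0.0.9" -> COMPLETED_WITH_WARNINGS     (older minor version)
#   "0.1.0" -> COMPLETED
#   "abc"   -> FAILED                      (unparsable version string)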
class RagPipelinePendingData(BaseModel):
import_mode: str
yaml_content: str
pipeline_id: str | None
class CheckDependenciesPendingData(BaseModel):
dependencies: list[PluginDependency]
pipeline_id: str | None
class RagPipelineDslService:
def __init__(self, session: Session):
self._session = session
def import_rag_pipeline(
self,
*,
account: Account,
import_mode: str,
yaml_content: str | None = None,
yaml_url: str | None = None,
pipeline_id: str | None = None,
dataset: Dataset | None = None,
dataset_name: str | None = None,
icon_info: IconInfo | None = None,
) -> RagPipelineImportInfo:
"""Import an app from YAML content or URL."""
import_id = str(uuid.uuid4())
# Validate import mode
try:
mode = ImportMode(import_mode)
except ValueError:
raise ValueError(f"Invalid import_mode: {import_mode}")
# Get YAML content
content: str = ""
if mode == ImportMode.YAML_URL:
if not yaml_url:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="yaml_url is required when import_mode is yaml-url",
)
try:
parsed_url = urlparse(yaml_url)
if (
parsed_url.scheme == "https"
and parsed_url.netloc == "github.com"
and parsed_url.path.endswith((".yml", ".yaml"))
):
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
yaml_url = yaml_url.replace("/blob/", "/")
response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
response.raise_for_status()
content = response.content.decode()
if len(content) > DSL_MAX_SIZE:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="File size exceeds the limit of 10MB",
)
if not content:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Empty content from url",
)
except Exception as e:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=f"Error fetching YAML from URL: {str(e)}",
)
elif mode == ImportMode.YAML_CONTENT:
if not yaml_content:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="yaml_content is required when import_mode is yaml-content",
)
content = yaml_content
# Process YAML content
try:
# Parse YAML to validate format
data = yaml.safe_load(content)
if not isinstance(data, dict):
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Invalid YAML format: content must be a mapping",
)
# Validate and fix DSL version
if not data.get("version"):
data["version"] = "0.1.0"
if not data.get("kind") or data.get("kind") != "rag_pipeline":
data["kind"] = "rag_pipeline"
imported_version = data.get("version", "0.1.0")
            # ensure the version is a string (YAML may parse bare numbers like 0.1 as floats)
if not isinstance(imported_version, str):
raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}")
status = _check_version_compatibility(imported_version)
            # Extract pipeline data
pipeline_data = data.get("rag_pipeline")
if not pipeline_data:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Missing rag_pipeline data in YAML content",
)
            # If pipeline_id is provided, check if it exists
pipeline = None
if pipeline_id:
stmt = select(Pipeline).where(
Pipeline.id == pipeline_id,
Pipeline.tenant_id == account.current_tenant_id,
)
pipeline = self._session.scalar(stmt)
if not pipeline:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Pipeline not found",
)
dataset = pipeline.retrieve_dataset(session=self._session)
# If major version mismatch, store import info in Redis
if status == ImportStatus.PENDING:
pending_data = RagPipelinePendingData(
import_mode=import_mode,
yaml_content=content,
pipeline_id=pipeline_id,
)
redis_client.setex(
f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}",
IMPORT_INFO_REDIS_EXPIRY,
pending_data.model_dump_json(),
)
return RagPipelineImportInfo(
id=import_id,
status=status,
pipeline_id=pipeline_id,
imported_dsl_version=imported_version,
)
# Extract dependencies
dependencies = data.get("dependencies", [])
check_dependencies_pending_data = None
if dependencies:
check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]
# Create or update pipeline
pipeline = self._create_or_update_pipeline(
pipeline=pipeline,
data=data,
account=account,
dependencies=check_dependencies_pending_data,
)
# create dataset
name = pipeline.name or "Untitled"
description = pipeline.description
if icon_info:
icon_type = icon_info.icon_type
icon = icon_info.icon
icon_background = icon_info.icon_background
icon_url = icon_info.icon_url
else:
icon_type = data.get("rag_pipeline", {}).get("icon_type")
icon = data.get("rag_pipeline", {}).get("icon")
icon_background = data.get("rag_pipeline", {}).get("icon_background")
icon_url = data.get("rag_pipeline", {}).get("icon_url")
workflow = data.get("workflow", {})
graph = workflow.get("graph", {})
nodes = graph.get("nodes", [])
dataset_id = None
for node in nodes:
if node.get("data", {}).get("type") == "knowledge-index":
knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
if (
dataset
and pipeline.is_published
and dataset.chunk_structure != knowledge_configuration.chunk_structure
):
raise ValueError("Chunk structure is not compatible with the published pipeline")
if not dataset:
datasets = self._session.query(Dataset).filter_by(tenant_id=account.current_tenant_id).all()
names = [dataset.name for dataset in datasets]
generate_name = generate_incremental_name(names, name)
dataset = Dataset(
tenant_id=account.current_tenant_id,
name=generate_name,
description=description,
icon_info={
"icon_type": icon_type,
"icon": icon,
"icon_background": icon_background,
"icon_url": icon_url,
},
indexing_technique=knowledge_configuration.indexing_technique,
created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
runtime_mode="rag_pipeline",
chunk_structure=knowledge_configuration.chunk_structure,
)
if knowledge_configuration.indexing_technique == "high_quality":
dataset_collection_binding = (
self._session.query(DatasetCollectionBinding)
.where(
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
DatasetCollectionBinding.type == "dataset",
)
.order_by(DatasetCollectionBinding.created_at)
.first()
)
if not dataset_collection_binding:
dataset_collection_binding = DatasetCollectionBinding(
provider_name=knowledge_configuration.embedding_model_provider,
model_name=knowledge_configuration.embedding_model,
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
type="dataset",
)
self._session.add(dataset_collection_binding)
self._session.commit()
dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.keyword_number
dataset.pipeline_id = pipeline.id
self._session.add(dataset)
self._session.commit()
dataset_id = dataset.id
if not dataset_id:
raise ValueError("DSL is not valid, please check the Knowledge Index node.")
return RagPipelineImportInfo(
id=import_id,
status=status,
pipeline_id=pipeline.id,
dataset_id=dataset_id,
imported_dsl_version=imported_version,
)
except yaml.YAMLError as e:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=f"Invalid YAML format: {str(e)}",
)
except Exception as e:
logger.exception("Failed to import app")
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=str(e),
)
def confirm_import(self, *, import_id: str, account: Account) -> RagPipelineImportInfo:
"""
Confirm an import that requires confirmation
"""
redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}"
pending_data = redis_client.get(redis_key)
if not pending_data:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Import information expired or does not exist",
)
try:
if not isinstance(pending_data, str | bytes):
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Invalid import information",
)
pending_data = RagPipelinePendingData.model_validate_json(pending_data)
data = yaml.safe_load(pending_data.yaml_content)
pipeline = None
if pending_data.pipeline_id:
stmt = select(Pipeline).where(
Pipeline.id == pending_data.pipeline_id,
Pipeline.tenant_id == account.current_tenant_id,
)
pipeline = self._session.scalar(stmt)
            # Create or update pipeline
pipeline = self._create_or_update_pipeline(
pipeline=pipeline,
data=data,
account=account,
)
dataset = pipeline.retrieve_dataset(session=self._session)
# create dataset
name = pipeline.name
description = pipeline.description
icon_type = data.get("rag_pipeline", {}).get("icon_type")
icon = data.get("rag_pipeline", {}).get("icon")
icon_background = data.get("rag_pipeline", {}).get("icon_background")
icon_url = data.get("rag_pipeline", {}).get("icon_url")
workflow = data.get("workflow", {})
graph = workflow.get("graph", {})
nodes = graph.get("nodes", [])
dataset_id = None
for node in nodes:
if node.get("data", {}).get("type") == "knowledge-index":
knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
if not dataset:
dataset = Dataset(
tenant_id=account.current_tenant_id,
name=name,
description=description,
icon_info={
"icon_type": icon_type,
"icon": icon,
"icon_background": icon_background,
"icon_url": icon_url,
},
indexing_technique=knowledge_configuration.indexing_technique,
created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
runtime_mode="rag_pipeline",
chunk_structure=knowledge_configuration.chunk_structure,
)
else:
dataset.indexing_technique = knowledge_configuration.indexing_technique
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
dataset.runtime_mode = "rag_pipeline"
dataset.chunk_structure = knowledge_configuration.chunk_structure
if knowledge_configuration.indexing_technique == "high_quality":
dataset_collection_binding = (
self._session.query(DatasetCollectionBinding)
.where(
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
DatasetCollectionBinding.type == "dataset",
)
.order_by(DatasetCollectionBinding.created_at)
.first()
)
if not dataset_collection_binding:
dataset_collection_binding = DatasetCollectionBinding(
provider_name=knowledge_configuration.embedding_model_provider,
model_name=knowledge_configuration.embedding_model,
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
type="dataset",
)
self._session.add(dataset_collection_binding)
self._session.commit()
dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.keyword_number
dataset.pipeline_id = pipeline.id
self._session.add(dataset)
self._session.commit()
dataset_id = dataset.id
if not dataset_id:
raise ValueError("DSL is not valid, please check the Knowledge Index node.")
# Delete import info from Redis
redis_client.delete(redis_key)
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.COMPLETED,
pipeline_id=pipeline.id,
dataset_id=dataset_id,
current_dsl_version=CURRENT_DSL_VERSION,
imported_dsl_version=data.get("version", "0.1.0"),
)
except Exception as e:
logger.exception("Error confirming import")
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=str(e),
)
def check_dependencies(
self,
*,
pipeline: Pipeline,
) -> CheckDependenciesResult:
"""Check dependencies"""
# Get dependencies from Redis
redis_key = f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{pipeline.id}"
dependencies = redis_client.get(redis_key)
if not dependencies:
return CheckDependenciesResult()
# Extract dependencies
dependencies = CheckDependenciesPendingData.model_validate_json(dependencies)
# Get leaked dependencies
leaked_dependencies = DependenciesAnalysisService.get_leaked_dependencies(
tenant_id=pipeline.tenant_id, dependencies=dependencies.dependencies
)
return CheckDependenciesResult(
leaked_dependencies=leaked_dependencies,
)
def _create_or_update_pipeline(
self,
*,
pipeline: Pipeline | None,
data: dict,
account: Account,
dependencies: list[PluginDependency] | None = None,
) -> Pipeline:
"""Create a new app or update an existing one."""
if not account.current_tenant_id:
raise ValueError("Tenant id is required")
pipeline_data = data.get("rag_pipeline", {})
# Initialize pipeline based on mode
workflow_data = data.get("workflow")
if not workflow_data or not isinstance(workflow_data, dict):
raise ValueError("Missing workflow data for rag pipeline")
environment_variables_list = workflow_data.get("environment_variables", [])
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = workflow_data.get("conversation_variables", [])
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])
graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
decrypted_id
for dataset_id in dataset_ids
if (
decrypted_id := self.decrypt_dataset_id(
encrypted_data=dataset_id,
tenant_id=account.current_tenant_id,
)
)
]
if pipeline:
# Update existing pipeline
pipeline.name = pipeline_data.get("name", pipeline.name)
pipeline.description = pipeline_data.get("description", pipeline.description)
pipeline.updated_by = account.id
else:
if account.current_tenant_id is None:
raise ValueError("Current tenant is not set")
            # Create new pipeline
pipeline = Pipeline(
tenant_id=account.current_tenant_id,
name=pipeline_data.get("name", ""),
description=pipeline_data.get("description", ""),
created_by=account.id,
updated_by=account.id,
)
pipeline.id = str(uuid4())
self._session.add(pipeline)
self._session.commit()
# save dependencies
if dependencies:
redis_client.setex(
f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{pipeline.id}",
IMPORT_INFO_REDIS_EXPIRY,
CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(),
)
workflow = (
self._session.query(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == "draft",
)
.first()
)
# create draft workflow if not found
if not workflow:
workflow = Workflow(
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
features="{}",
type=WorkflowType.RAG_PIPELINE,
version="draft",
graph=json.dumps(graph),
created_by=account.id,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
rag_pipeline_variables=rag_pipeline_variables_list,
)
self._session.add(workflow)
self._session.flush()
pipeline.workflow_id = workflow.id
else:
workflow.graph = json.dumps(graph)
workflow.updated_by = account.id
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
workflow.environment_variables = environment_variables
workflow.conversation_variables = conversation_variables
workflow.rag_pipeline_variables = rag_pipeline_variables_list
# commit db session changes
self._session.commit()
return pipeline
def export_rag_pipeline_dsl(self, pipeline: Pipeline, include_secret: bool = False) -> str:
"""
Export pipeline
:param pipeline: Pipeline instance
:param include_secret: Whether include secret variable
:return:
"""
dataset = pipeline.retrieve_dataset(session=self._session)
if not dataset:
raise ValueError("Missing dataset for rag pipeline")
icon_info = dataset.icon_info
export_data = {
"version": CURRENT_DSL_VERSION,
"kind": "rag_pipeline",
"rag_pipeline": {
"name": dataset.name,
"icon": icon_info.get("icon", "📙") if icon_info else "📙",
"icon_type": icon_info.get("icon_type", "emoji") if icon_info else "emoji",
"icon_background": icon_info.get("icon_background", "#FFEAD5") if icon_info else "#FFEAD5",
"icon_url": icon_info.get("icon_url") if icon_info else None,
"description": pipeline.description,
},
}
self._append_workflow_export_data(export_data=export_data, pipeline=pipeline, include_secret=include_secret)
return yaml.dump(export_data, allow_unicode=True) # type: ignore
def _append_workflow_export_data(self, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None:
"""
Append workflow export data
:param export_data: export data
:param pipeline: Pipeline instance
"""
workflow = (
self._session.query(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == "draft",
)
.first()
)
if not workflow:
raise ValueError("Missing draft workflow configuration, please check.")
workflow_dict = workflow.to_dict(include_secret=include_secret)
for node in workflow_dict.get("graph", {}).get("nodes", []):
node_data = node.get("data", {})
if not node_data:
continue
data_type = node_data.get("type", "")
if data_type == NodeType.KNOWLEDGE_RETRIEVAL:
dataset_ids = node_data.get("dataset_ids", [])
node["data"]["dataset_ids"] = [
self.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id)
for dataset_id in dataset_ids
]
# filter credential id from tool node
if not include_secret and data_type == NodeType.TOOL:
node_data.pop("credential_id", None)
# filter credential id from agent node
if not include_secret and data_type == NodeType.AGENT:
for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
tool.pop("credential_id", None)
export_data["workflow"] = workflow_dict
dependencies = self._extract_dependencies_from_workflow(workflow)
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())
for d in DependenciesAnalysisService.generate_dependencies(
tenant_id=pipeline.tenant_id, dependencies=dependencies
)
]
def _extract_dependencies_from_workflow(self, workflow: Workflow) -> list[str]:
"""
Extract dependencies from workflow
:param workflow: Workflow instance
:return: dependencies list format like ["langgenius/google"]
"""
graph = workflow.graph_dict
dependencies = self._extract_dependencies_from_workflow_graph(graph)
return dependencies
def _extract_dependencies_from_workflow_graph(self, graph: Mapping) -> list[str]:
"""
Extract dependencies from workflow graph
:param graph: Workflow graph
:return: dependencies list format like ["langgenius/google"]
"""
dependencies = []
for node in graph.get("nodes", []):
try:
typ = node.get("data", {}).get("type")
match typ:
case NodeType.TOOL:
tool_entity = ToolNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id),
)
case NodeType.DATASOURCE:
datasource_entity = DatasourceNodeData.model_validate(node["data"])
if datasource_entity.provider_type != "local_file":
dependencies.append(datasource_entity.plugin_id)
case NodeType.LLM:
llm_entity = LLMNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider),
)
case NodeType.QUESTION_CLASSIFIER:
question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
question_classifier_entity.model.provider
),
)
case NodeType.PARAMETER_EXTRACTOR:
parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
parameter_extractor_entity.model.provider
),
)
case NodeType.KNOWLEDGE_INDEX:
knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"])
if knowledge_index_entity.indexing_technique == "high_quality":
if knowledge_index_entity.embedding_model_provider:
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
knowledge_index_entity.embedding_model_provider
),
)
if knowledge_index_entity.retrieval_model.reranking_mode == "reranking_model":
if knowledge_index_entity.retrieval_model.reranking_enable:
if (
knowledge_index_entity.retrieval_model.reranking_model
and knowledge_index_entity.retrieval_model.reranking_mode == "reranking_model"
):
if knowledge_index_entity.retrieval_model.reranking_model.reranking_provider_name:
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
knowledge_index_entity.retrieval_model.reranking_model.reranking_provider_name
),
)
case NodeType.KNOWLEDGE_RETRIEVAL:
knowledge_retrieval_entity = KnowledgeRetrievalNodeData.model_validate(node["data"])
if knowledge_retrieval_entity.retrieval_mode == "multiple":
if knowledge_retrieval_entity.multiple_retrieval_config:
if (
knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode
== "reranking_model"
):
if knowledge_retrieval_entity.multiple_retrieval_config.reranking_model:
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
knowledge_retrieval_entity.multiple_retrieval_config.reranking_model.provider
),
)
elif (
knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode
== "weighted_score"
):
if knowledge_retrieval_entity.multiple_retrieval_config.weights:
vector_setting = (
knowledge_retrieval_entity.multiple_retrieval_config.weights.vector_setting
)
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
vector_setting.embedding_provider_name
),
)
elif knowledge_retrieval_entity.retrieval_mode == "single":
model_config = knowledge_retrieval_entity.single_retrieval_config
if model_config:
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
model_config.model.provider
),
)
case _:
# TODO: Handle default case or unknown node types
pass
except Exception as e:
logger.exception("Error extracting node dependency", exc_info=e)
return dependencies
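    # Illustrative sketch (not part of the original file): a minimal graph dict of the shape
    # this method walks. Node contents and provider ids below are hypothetical; only the
    # "data.type" discriminator and the per-type fields validated above matter.
    #
    #   graph = {
    #       "nodes": [
    #           {"data": {"type": "tool", "provider_id": "langgenius/google/google", ...}},
    #           {"data": {"type": "llm", "model": {"provider": "langgenius/openai/openai", ...}}},
    #       ]
    #   }
    #   deps = self._extract_dependencies_from_workflow_graph(graph)
    #   # -> a list like ["langgenius/google", ...], after DependenciesAnalysisService
    #   #    normalizes each provider reference into the plugin dependency format.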
@classmethod
def _extract_dependencies_from_model_config(cls, model_config: Mapping) -> list[str]:
"""
Extract dependencies from model config
:param model_config: model config dict
:return: dependencies list format like ["langgenius/google"]
"""
dependencies = []
try:
# completion model
model_dict = model_config.get("model", {})
if model_dict:
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(model_dict.get("provider", ""))
)
# reranking model
dataset_configs = model_config.get("dataset_configs", {})
if dataset_configs:
for dataset_config in dataset_configs.get("datasets", {}).get("datasets", []):
if dataset_config.get("reranking_model"):
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
dataset_config.get("reranking_model", {})
.get("reranking_provider_name", {})
.get("provider")
)
)
# tools
agent_configs = model_config.get("agent_mode", {})
if agent_configs:
for agent_config in agent_configs.get("tools", []):
dependencies.append(
DependenciesAnalysisService.analyze_tool_dependency(agent_config.get("provider_id"))
)
except Exception as e:
logger.exception("Error extracting model config dependency", exc_info=e)
return dependencies
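    # Illustrative sketch (shape inferred from the .get() chains above; all values are
    # placeholders, not real provider ids):
    #
    #   model_config = {
    #       "model": {"provider": "langgenius/openai/openai"},
    #       "dataset_configs": {
    #           "datasets": {
    #               "datasets": [{"reranking_model": {"reranking_provider_name": {"provider": "..."}}}]
    #           }
    #       },
    #       "agent_mode": {"tools": [{"provider_id": "langgenius/google/google"}]},
    #   }
    #   deps = cls._extract_dependencies_from_model_config(model_config)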
@classmethod
def get_leaked_dependencies(cls, tenant_id: str, dsl_dependencies: list[dict]) -> list[PluginDependency]:
"""
        Returns the leaked dependencies in the current workspace
"""
dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies]
if not dependencies:
return []
return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)
def _generate_aes_key(self, tenant_id: str) -> bytes:
"""Generate AES key based on tenant_id"""
return hashlib.sha256(tenant_id.encode()).digest()
def encrypt_dataset_id(self, dataset_id: str, tenant_id: str) -> str:
"""Encrypt dataset_id using AES-CBC mode"""
key = self._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
ct_bytes = cipher.encrypt(pad(dataset_id.encode(), AES.block_size))
return base64.b64encode(ct_bytes).decode()
def decrypt_dataset_id(self, encrypted_data: str, tenant_id: str) -> str | None:
"""AES decryption"""
try:
key = self._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size)
return pt.decode()
except Exception:
return None
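    # Round-trip sketch of the scheme above (illustrative only): the AES-256 key is the
    # SHA-256 digest of the tenant id and the IV is its first 16 bytes, so the same
    # (dataset_id, tenant_id) pair always produces the same ciphertext.
    #
    #   token = service.encrypt_dataset_id("example-dataset-id", tenant_id="tenant-123")
    #   assert service.decrypt_dataset_id(token, tenant_id="tenant-123") == "example-dataset-id"
    #   service.decrypt_dataset_id(token, tenant_id="other-tenant")  # typically None (unpad fails)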
def create_rag_pipeline_dataset(
self,
tenant_id: str,
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
):
if rag_pipeline_dataset_create_entity.name:
# check if dataset name already exists
if (
self._session.query(Dataset)
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
.first()
):
raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.")
else:
            # otherwise generate an incremental default name: Untitled 1, Untitled 2, Untitled 3, ...
datasets = self._session.query(Dataset).filter_by(tenant_id=tenant_id).all()
names = [dataset.name for dataset in datasets]
rag_pipeline_dataset_create_entity.name = generate_incremental_name(
names,
"Untitled",
)
account = cast(Account, current_user)
rag_pipeline_import_info: RagPipelineImportInfo = self.import_rag_pipeline(
account=account,
import_mode=ImportMode.YAML_CONTENT,
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
dataset=None,
dataset_name=rag_pipeline_dataset_create_entity.name,
icon_info=rag_pipeline_dataset_create_entity.icon_info,
)
return {
"id": rag_pipeline_import_info.id,
"dataset_id": rag_pipeline_import_info.dataset_id,
"pipeline_id": rag_pipeline_import_info.pipeline_id,
"status": rag_pipeline_import_info.status,
"imported_dsl_version": rag_pipeline_import_info.imported_dsl_version,
"current_dsl_version": rag_pipeline_import_info.current_dsl_version,
"error": rag_pipeline_import_info.error,
}
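    # Illustrative call (field names taken from the usage above; the variable names and
    # values are hypothetical):
    #
    #   entity = RagPipelineDatasetCreateEntity(
    #       name="",  # falsy name -> an incremental "Untitled N" name is generated
    #       yaml_content=dsl_yaml_string,
    #       icon_info=icon_info,
    #   )
    #   result = service.create_rag_pipeline_dataset(tenant_id, entity)
    #   # result["dataset_id"], result["pipeline_id"], result["status"], ...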

View File

@@ -0,0 +1,23 @@
from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
from core.plugin.impl.datasource import PluginDatasourceManager
from services.datasource_provider_service import DatasourceProviderService
class RagPipelineManageService:
@staticmethod
def list_rag_pipeline_datasources(tenant_id: str) -> list[PluginDatasourceProviderEntity]:
"""
list rag pipeline datasources
"""
# get all builtin providers
manager = PluginDatasourceManager()
datasources = manager.fetch_datasource_providers(tenant_id)
        datasource_provider_service = DatasourceProviderService()
        for datasource in datasources:
            credentials = datasource_provider_service.get_datasource_credentials(
                tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
            )
if credentials:
datasource.is_authorized = True
return datasources
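    # Usage sketch (illustrative; the tenant id is a placeholder):
    #
    #   providers = RagPipelineManageService.list_rag_pipeline_datasources(tenant_id="tenant-123")
    #   authorized = [p.provider for p in providers if p.is_authorized]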

View File

@@ -0,0 +1,106 @@
import json
import logging
from collections.abc import Callable, Sequence
from functools import cached_property
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from services.feature_service import FeatureService
from services.file_service import FileService
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
logger = logging.getLogger(__name__)
class RagPipelineTaskProxy:
# Default uploaded file name for rag pipeline invoke entities
_RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME = "rag_pipeline_invoke_entities.json"
def __init__(
self, dataset_tenant_id: str, user_id: str, rag_pipeline_invoke_entities: Sequence[RagPipelineInvokeEntity]
):
self._dataset_tenant_id = dataset_tenant_id
self._user_id = user_id
self._rag_pipeline_invoke_entities = rag_pipeline_invoke_entities
self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(dataset_tenant_id, "pipeline")
@cached_property
def features(self):
return FeatureService.get_features(self._dataset_tenant_id)
def _upload_invoke_entities(self) -> str:
text = [item.model_dump() for item in self._rag_pipeline_invoke_entities]
# Convert list to proper JSON string
json_text = json.dumps(text)
upload_file = FileService(db.engine).upload_text(
json_text, self._RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME, self._user_id, self._dataset_tenant_id
)
return upload_file.id
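    # Design note (inferred from the call sites below): the serialized invoke entities are
    # persisted through FileService and only the upload file id is handed to the Celery
    # tasks, presumably to keep broker messages small; the worker is expected to re-load
    # and re-validate the JSON payload from that file id.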
def _send_to_direct_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]):
logger.info("send file %s to direct queue", upload_file_id)
task_func.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file_id,
tenant_id=self._dataset_tenant_id,
)
def _send_to_tenant_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]):
logger.info("send file %s to tenant queue", upload_file_id)
if self._tenant_isolated_task_queue.get_task_key():
# Add to waiting queue using List operations (lpush)
self._tenant_isolated_task_queue.push_tasks([upload_file_id])
logger.info("push tasks: %s", upload_file_id)
else:
# Set flag and execute task
self._tenant_isolated_task_queue.set_task_waiting_time()
task_func.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file_id,
tenant_id=self._dataset_tenant_id,
)
logger.info("init tasks: %s", upload_file_id)
def _send_to_default_tenant_queue(self, upload_file_id: str):
self._send_to_tenant_queue(upload_file_id, rag_pipeline_run_task)
def _send_to_priority_tenant_queue(self, upload_file_id: str):
self._send_to_tenant_queue(upload_file_id, priority_rag_pipeline_run_task)
def _send_to_priority_direct_queue(self, upload_file_id: str):
self._send_to_direct_queue(upload_file_id, priority_rag_pipeline_run_task)
def _dispatch(self):
upload_file_id = self._upload_invoke_entities()
if not upload_file_id:
raise ValueError("upload_file_id is empty")
logger.info(
"dispatch args: %s - %s - %s",
self._dataset_tenant_id,
self.features.billing.enabled,
self.features.billing.subscription.plan,
)
# dispatch to different pipeline queue with tenant isolation when billing enabled
if self.features.billing.enabled:
if self.features.billing.subscription.plan == CloudPlan.SANDBOX:
# dispatch to normal pipeline queue with tenant isolation for sandbox plan
self._send_to_default_tenant_queue(upload_file_id)
else:
# dispatch to priority pipeline queue with tenant isolation for other plans
self._send_to_priority_tenant_queue(upload_file_id)
else:
            # dispatch to priority pipeline queue without tenant isolation for other deployments, e.g. self-hosted or enterprise
self._send_to_priority_direct_queue(upload_file_id)
def delay(self):
if not self._rag_pipeline_invoke_entities:
logger.warning(
"Received empty rag pipeline invoke entities, no tasks delivered: %s %s",
self._dataset_tenant_id,
self._user_id,
)
return
self._dispatch()
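# Usage sketch (illustrative; entity construction is elided because RagPipelineInvokeEntity's
# fields are defined elsewhere, and the user id below is a placeholder):
#
#   proxy = RagPipelineTaskProxy(
#       dataset_tenant_id=tenant_id,
#       user_id=current_user.id,
#       rag_pipeline_invoke_entities=invoke_entities,
#   )
#   proxy.delay()  # uploads the payload and routes it to the appropriate Celery queue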

Some files were not shown because too many files have changed in this diff.