This commit is contained in:
2025-12-01 17:21:38 +08:00
parent 32fee2b8ab
commit fab8c13cb3
7511 changed files with 996300 additions and 0 deletions

View File

@@ -0,0 +1,474 @@
import json
import logging
from collections.abc import Mapping
from typing import Any, cast
from httpx import get
from sqlalchemy import select
from core.entities.provider_entities import ProviderConfig
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ApiProviderSchemaType,
)
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import create_tool_provider_encrypter
from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db
from models.tools import ApiToolProvider
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
class ApiToolManageService:
@staticmethod
def parser_api_schema(schema: str) -> Mapping[str, Any]:
"""
parse api schema to tool bundle
"""
try:
warnings: dict[str, str] = {}
try:
tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
except Exception as e:
raise ValueError(f"invalid schema: {str(e)}")
credentials_schema = [
ProviderConfig(
name="auth_type",
type=ProviderConfig.Type.SELECT,
required=True,
default="none",
options=[
ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="")),
ProviderConfig.Option(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
],
placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
),
ProviderConfig(
name="api_key_header",
type=ProviderConfig.Type.TEXT_INPUT,
required=False,
placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key headerX-API-KEY"),
default="api_key",
help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"),
),
ProviderConfig(
name="api_key_value",
type=ProviderConfig.Type.TEXT_INPUT,
required=False,
placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"),
default="",
),
]
return cast(
Mapping,
jsonable_encoder(
{
"schema_type": schema_type,
"parameters_schema": tool_bundles,
"credentials_schema": credentials_schema,
"warning": warnings,
}
),
)
except Exception as e:
raise ValueError(f"invalid schema: {str(e)}")
@staticmethod
def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]:
"""
convert schema to tool bundles
:return: the list of tool bundles, description
"""
try:
return ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
except Exception as e:
raise ValueError(f"invalid schema: {str(e)}")
@staticmethod
def create_api_tool_provider(
user_id: str,
tenant_id: str,
provider_name: str,
icon: dict,
credentials: dict,
schema_type: str,
schema: str,
privacy_policy: str,
custom_disclaimer: str,
labels: list[str],
):
"""
create api tool provider
"""
if schema_type not in [member.value for member in ApiProviderSchemaType]:
raise ValueError(f"invalid schema type {schema}")
provider_name = provider_name.strip()
# check if the provider exists
provider = (
db.session.query(ApiToolProvider)
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
.first()
)
if provider is not None:
raise ValueError(f"provider {provider_name} already exists")
# parse openapi to tool bundle
extra_info: dict[str, str] = {}
# extra info like description will be set here
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
if len(tool_bundles) > 100:
raise ValueError("the number of apis should be less than 100")
# create db provider
db_provider = ApiToolProvider(
tenant_id=tenant_id,
user_id=user_id,
name=provider_name,
icon=json.dumps(icon),
schema=schema,
description=extra_info.get("description", ""),
schema_type_str=schema_type,
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
credentials_str="{}",
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
)
if "auth_type" not in credentials:
raise ValueError("auth_type is required")
# get auth type, none or api key
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
# create provider entity
provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
# load tools into provider entity
provider_controller.load_bundled_tools(tool_bundles)
# encrypt credentials
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
controller=provider_controller,
)
db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials))
db.session.add(db_provider)
db.session.commit()
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
return {"result": "success"}
@staticmethod
def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str):
"""
get api tool provider remote schema
"""
headers = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko)"
" Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
"Accept": "*/*",
}
try:
response = get(url, headers=headers, timeout=10)
if response.status_code != 200:
raise ValueError(f"Got status code {response.status_code}")
schema = response.text
# try to parse schema, avoid SSRF attack
ApiToolManageService.parser_api_schema(schema)
except Exception:
logger.exception("parse api schema error")
raise ValueError("invalid schema, please check the url you provided")
return {"schema": schema}
@staticmethod
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[ToolApiEntity]:
"""
list api tool provider tools
"""
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
.first()
)
if provider is None:
raise ValueError(f"you have not added provider {provider_name}")
controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
labels = ToolLabelManager.get_tool_labels(controller)
return [
ToolTransformService.convert_tool_entity_to_api_entity(
tool_bundle,
tenant_id=tenant_id,
labels=labels,
)
for tool_bundle in provider.tools
]
@staticmethod
def update_api_tool_provider(
user_id: str,
tenant_id: str,
provider_name: str,
original_provider: str,
icon: dict,
credentials: dict,
schema_type: str,
schema: str,
privacy_policy: str,
custom_disclaimer: str,
labels: list[str],
):
"""
update api tool provider
"""
if schema_type not in [member.value for member in ApiProviderSchemaType]:
raise ValueError(f"invalid schema type {schema}")
provider_name = provider_name.strip()
# check if the provider exists
provider = (
db.session.query(ApiToolProvider)
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == original_provider,
)
.first()
)
if provider is None:
raise ValueError(f"api provider {provider_name} does not exists")
# parse openapi to tool bundle
extra_info: dict[str, str] = {}
# extra info like description will be set here
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
# update db provider
provider.name = provider_name
provider.icon = json.dumps(icon)
provider.schema = schema
provider.description = extra_info.get("description", "")
provider.schema_type_str = ApiProviderSchemaType.OPENAPI
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
provider.privacy_policy = privacy_policy
provider.custom_disclaimer = custom_disclaimer
if "auth_type" not in credentials:
raise ValueError("auth_type is required")
# get auth type, none or api key
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
# create provider entity
provider_controller = ApiToolProviderController.from_db(provider, auth_type)
# load tools into provider entity
provider_controller.load_bundled_tools(tool_bundles)
# get original credentials if exists
encrypter, cache = create_tool_provider_encrypter(
tenant_id=tenant_id,
controller=provider_controller,
)
original_credentials = encrypter.decrypt(provider.credentials)
masked_credentials = encrypter.mask_plugin_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = original_credentials[name]
credentials = dict(encrypter.encrypt(credentials))
provider.credentials_str = json.dumps(credentials)
db.session.add(provider)
db.session.commit()
# delete cache
cache.delete()
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
return {"result": "success"}
@staticmethod
def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str):
"""
delete tool provider
"""
provider = (
db.session.query(ApiToolProvider)
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
.first()
)
if provider is None:
raise ValueError(f"you have not added provider {provider_name}")
db.session.delete(provider)
db.session.commit()
return {"result": "success"}
@staticmethod
def get_api_tool_provider(user_id: str, tenant_id: str, provider: str):
"""
get api tool provider
"""
return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id)
@staticmethod
def test_api_tool_preview(
tenant_id: str,
provider_name: str,
tool_name: str,
credentials: dict,
parameters: dict,
schema_type: str,
schema: str,
):
"""
test api tool before adding api tool provider
"""
if schema_type not in [member.value for member in ApiProviderSchemaType]:
raise ValueError(f"invalid schema type {schema_type}")
try:
tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
except Exception:
raise ValueError("invalid schema")
# get tool bundle
tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None)
if tool_bundle is None:
raise ValueError(f"invalid tool name {tool_name}")
db_provider = (
db.session.query(ApiToolProvider)
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
.first()
)
if not db_provider:
# create a fake db provider
db_provider = ApiToolProvider(
tenant_id="",
user_id="",
name="",
icon="",
schema=schema,
description="",
schema_type_str=ApiProviderSchemaType.OPENAPI,
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
credentials_str=json.dumps(credentials),
)
if "auth_type" not in credentials:
raise ValueError("auth_type is required")
# get auth type, none or api key
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
# create provider entity
provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
# load tools into provider entity
provider_controller.load_bundled_tools(tool_bundles)
# decrypt credentials
if db_provider.id:
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
controller=provider_controller,
)
decrypted_credentials = encrypter.decrypt(credentials)
# check if the credential has changed, save the original credential
masked_credentials = encrypter.mask_plugin_credentials(decrypted_credentials)
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = decrypted_credentials[name]
try:
provider_controller.validate_credentials_format(credentials)
# get tool
tool = provider_controller.get_tool(tool_name)
tool = tool.fork_tool_runtime(
runtime=ToolRuntime(
credentials=credentials,
tenant_id=tenant_id,
)
)
result = tool.validate_credentials(credentials, parameters)
except Exception as e:
return {"error": str(e)}
return {"result": result or "empty response"}
@staticmethod
def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]:
"""
list api tools
"""
# get all api providers
db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all()
result: list[ToolProviderApiEntity] = []
for provider in db_providers:
# convert provider controller to user provider
provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
labels = ToolLabelManager.get_tool_labels(provider_controller)
user_provider = ToolTransformService.api_provider_to_user_provider(
provider_controller, db_provider=provider, decrypt_credentials=True
)
user_provider.labels = labels
# add icon
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_provider)
tools = provider_controller.get_tools(tenant_id=tenant_id)
for tool in tools or []:
user_provider.tools.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id, tool=tool, labels=labels
)
)
result.append(user_provider)
return result

View File

@@ -0,0 +1,727 @@
import json
import logging
from collections.abc import Mapping
from pathlib import Path
from typing import Any
from sqlalchemy import exists, select
from sqlalchemy.orm import Session
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.name_generator import generate_incremental_name
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.entities.api_entities import (
ToolApiEntity,
ToolProviderApiEntity,
ToolProviderCredentialApiEntity,
ToolProviderCredentialInfoApiEntity,
)
from core.tools.errors import ToolProviderNotFoundError
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import create_provider_encrypter
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.provider_ids import ToolProviderID
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
from services.plugin.plugin_service import PluginService
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
class BuiltinToolManageService:
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
@staticmethod
def delete_custom_oauth_client_params(tenant_id: str, provider: str):
"""
delete custom oauth client params
"""
tool_provider = ToolProviderID(provider)
with Session(db.engine) as session:
session.query(ToolOAuthTenantClient).filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
).delete()
session.commit()
return {"result": "success"}
@staticmethod
def get_builtin_tool_provider_oauth_client_schema(tenant_id: str, provider_name: str):
"""
get builtin tool provider oauth client schema
"""
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified(
tenant_id, provider.plugin_unique_identifier
)
is_oauth_custom_client_enabled = BuiltinToolManageService.is_oauth_custom_client_enabled(
tenant_id, provider_name
)
is_system_oauth_params_exists = verified and BuiltinToolManageService.is_oauth_system_client_exists(
provider_name
)
result = {
"schema": provider.get_oauth_client_schema(),
"is_oauth_custom_client_enabled": is_oauth_custom_client_enabled,
"is_system_oauth_params_exists": is_system_oauth_params_exists,
"client_params": BuiltinToolManageService.get_custom_oauth_client_params(tenant_id, provider_name),
"redirect_uri": f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_name}/tool/callback",
}
return result
@staticmethod
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
"""
list builtin tool provider tools
:param tenant_id: the id of the tenant
:param provider: the name of the provider
:return: the list of tools
"""
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tools = provider_controller.get_tools()
result: list[ToolApiEntity] = []
for tool in tools or []:
result.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tool=tool,
tenant_id=tenant_id,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
)
return result
@staticmethod
def get_builtin_tool_provider_info(tenant_id: str, provider: str):
"""
get builtin tool provider info
"""
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
# check if user has added the provider
builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id)
if builtin_provider is None:
raise ValueError(f"you have not added provider {provider}")
entity = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
db_provider=builtin_provider,
decrypt_credentials=True,
)
entity.original_credentials = {}
return entity
@staticmethod
def list_builtin_provider_credentials_schema(provider_name: str, credential_type: CredentialType, tenant_id: str):
"""
list builtin provider credentials schema
:param credential_type: credential type
:param provider_name: the name of the provider
:param tenant_id: the id of the tenant
:return: the list of tool providers
"""
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
return provider.get_credentials_schema_by_type(credential_type)
@staticmethod
def update_builtin_tool_provider(
user_id: str,
tenant_id: str,
provider: str,
credential_id: str,
credentials: dict | None = None,
name: str | None = None,
):
"""
update builtin tool provider
"""
with Session(db.engine) as session:
# get if the provider exists
db_provider = (
session.query(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
if db_provider is None:
raise ValueError(f"you have not added provider {provider}")
try:
if CredentialType.of(db_provider.credential_type).is_editable() and credentials:
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider} does not need credentials")
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
tenant_id, db_provider, provider, provider_controller
)
original_credentials = encrypter.decrypt(db_provider.credentials)
new_credentials: dict = {
key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
if CredentialType.of(db_provider.credential_type).is_validate_allowed():
provider_controller.validate_credentials(user_id, new_credentials)
# encrypt credentials
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))
cache.delete()
# update name if provided
if name and name != db_provider.name:
# check if the name is already used
if session.scalar(
select(
exists().where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
BuiltinToolProvider.name == name,
)
)
):
raise ValueError(f"the credential name '{name}' is already used")
db_provider.name = name
session.commit()
except Exception as e:
session.rollback()
raise ValueError(str(e))
return {"result": "success"}
@staticmethod
def add_builtin_tool_provider(
user_id: str,
api_type: CredentialType,
tenant_id: str,
provider: str,
credentials: dict,
expires_at: int = -1,
name: str | None = None,
):
"""
add builtin tool provider
"""
with Session(db.engine) as session:
try:
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
with redis_client.lock(lock, timeout=20):
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider} does not need credentials")
provider_count = (
session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
)
# check if the provider count is reached the limit
if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
raise ValueError(f"you have reached the maximum number of providers for {provider}")
# validate credentials if allowed
if CredentialType.of(api_type).is_validate_allowed():
provider_controller.validate_credentials(user_id, credentials)
# generate name if not provided
if name is None or name == "":
name = BuiltinToolManageService.generate_builtin_tool_provider_name(
session=session, tenant_id=tenant_id, provider=provider, credential_type=api_type
)
else:
# check if the name is already used
if session.scalar(
select(
exists().where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
BuiltinToolProvider.name == name,
)
)
):
raise ValueError(f"the credential name '{name}' is already used")
# create encrypter
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(api_type)
],
cache=NoOpProviderCredentialCache(),
)
db_provider = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider,
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
credential_type=api_type.value,
name=name,
expires_at=expires_at if expires_at is not None else -1,
)
session.add(db_provider)
session.commit()
except Exception as e:
session.rollback()
raise ValueError(str(e))
return {"result": "success"}
@staticmethod
def create_tool_encrypter(
tenant_id: str,
db_provider: BuiltinToolProvider,
provider: str,
provider_controller: BuiltinToolProviderController,
):
encrypter, cache = create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(db_provider.credential_type)
],
cache=ToolProviderCredentialsCache(tenant_id=tenant_id, provider=provider, credential_id=db_provider.id),
)
return encrypter, cache
@staticmethod
def generate_builtin_tool_provider_name(
session: Session, tenant_id: str, provider: str, credential_type: CredentialType
) -> str:
db_providers = (
session.query(BuiltinToolProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider,
credential_type=credential_type.value,
)
.order_by(BuiltinToolProvider.created_at.desc())
.all()
)
return generate_incremental_name(
[provider.name for provider in db_providers],
f"{credential_type.get_name()}",
)
@staticmethod
def get_builtin_tool_provider_credentials(
tenant_id: str, provider_name: str
) -> list[ToolProviderCredentialApiEntity]:
"""
get builtin tool provider credentials
"""
with db.session.no_autoflush:
providers = (
db.session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider_name)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.all()
)
if len(providers) == 0:
return []
default_provider = providers[0]
default_provider.is_default = True
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
credentials: list[ToolProviderCredentialApiEntity] = []
for provider in providers:
encrypter, _ = BuiltinToolManageService.create_tool_encrypter(
tenant_id, provider, provider.provider, provider_controller
)
decrypt_credential = encrypter.mask_plugin_credentials(encrypter.decrypt(provider.credentials))
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
provider=provider,
credentials=dict(decrypt_credential),
)
credentials.append(credential_entity)
return credentials
@staticmethod
def get_builtin_tool_provider_credential_info(tenant_id: str, provider: str) -> ToolProviderCredentialInfoApiEntity:
"""
get builtin tool provider credential info
"""
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
supported_credential_types = provider_controller.get_supported_credential_types()
credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider)
credential_info = ToolProviderCredentialInfoApiEntity(
supported_credential_types=supported_credential_types,
is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider),
credentials=credentials,
)
return credential_info
@staticmethod
def delete_builtin_tool_provider(tenant_id: str, provider: str, credential_id: str):
"""
delete tool provider
"""
with Session(db.engine) as session:
db_provider = (
session.query(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
if db_provider is None:
raise ValueError(f"you have not added provider {provider}")
session.delete(db_provider)
session.commit()
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
_, cache = BuiltinToolManageService.create_tool_encrypter(
tenant_id, db_provider, provider, provider_controller
)
cache.delete()
return {"result": "success"}
@staticmethod
def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str):
"""
set default provider
"""
with Session(db.engine) as session:
# get provider
target_provider = session.query(BuiltinToolProvider).filter_by(id=id).first()
if target_provider is None:
raise ValueError("provider not found")
# clear default provider
session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
).update({"is_default": False})
# set new default provider
target_provider.is_default = True
session.commit()
return {"result": "success"}
@staticmethod
def is_oauth_system_client_exists(provider_name: str) -> bool:
"""
check if oauth system client exists
"""
tool_provider = ToolProviderID(provider_name)
with Session(db.engine, autoflush=False) as session:
system_client: ToolOAuthSystemClient | None = (
session.query(ToolOAuthSystemClient)
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
.first()
)
return system_client is not None
@staticmethod
def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool:
"""
check if oauth custom client is enabled
"""
tool_provider = ToolProviderID(provider)
with Session(db.engine, autoflush=False) as session:
user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
enabled=True,
)
.first()
)
return user_client is not None and user_client.enabled
@staticmethod
def get_oauth_client(tenant_id: str, provider: str) -> Mapping[str, Any] | None:
"""
get builtin tool provider
"""
tool_provider = ToolProviderID(provider)
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
with Session(db.engine, autoflush=False) as session:
user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
enabled=True,
)
.first()
)
oauth_params: Mapping[str, Any] | None = None
if user_client:
oauth_params = encrypter.decrypt(user_client.oauth_params)
return oauth_params
# only verified provider can use official oauth client
is_verified = not isinstance(
provider_controller, PluginToolProviderController
) or PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
if not is_verified:
return oauth_params
system_client: ToolOAuthSystemClient | None = (
session.query(ToolOAuthSystemClient)
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
.first()
)
if system_client:
try:
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
except Exception as e:
raise ValueError(f"Error decrypting system oauth params: {e}")
return oauth_params
@staticmethod
def get_builtin_tool_provider_icon(provider: str):
"""
get tool provider icon and it's mimetype
"""
icon_path, mime_type = ToolManager.get_hardcoded_provider_icon(provider)
icon_bytes = Path(icon_path).read_bytes()
return icon_bytes, mime_type
@staticmethod
def list_builtin_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
"""
list builtin tools
"""
# get all builtin providers
provider_controllers = ToolManager.list_builtin_providers(tenant_id)
# get all user added providers
db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
# rewrite db_providers
for db_provider in db_providers:
db_provider.provider = str(ToolProviderID(db_provider.provider))
# find provider
def find_provider(provider):
return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
result: list[ToolProviderApiEntity] = []
for provider_controller in provider_controllers:
try:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
data=provider_controller,
name_func=lambda x: x.entity.identity.name,
):
continue
# convert provider controller to user provider
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
db_provider=find_provider(provider_controller.entity.identity.name),
decrypt_credentials=True,
)
# add icon
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
tools = provider_controller.get_tools()
for tool in tools or []:
user_builtin_provider.tools.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id,
tool=tool,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
)
result.append(user_builtin_provider)
except Exception as e:
raise e
return BuiltinToolProviderSort.sort(result)
@staticmethod
def get_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
"""
This method is used to fetch the builtin provider from the database
1.if the default provider exists, return the default provider
2.if the default provider does not exist, return the oldest provider
"""
with Session(db.engine, autoflush=False) as session:
try:
full_provider_name = provider_name
provider_id_entity = ToolProviderID(provider_name)
provider_name = provider_id_entity.provider_name
if provider_id_entity.organization != "langgenius":
provider = (
session.query(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == full_provider_name,
)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
)
else:
provider = (
session.query(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == provider_name)
| (BuiltinToolProvider.provider == full_provider_name),
)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
)
if provider is None:
return None
provider.provider = ToolProviderID(provider.provider).to_string()
return provider
except Exception:
# it's an old provider without organization
return (
session.query(BuiltinToolProvider)
.where(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
)
@staticmethod
def save_custom_oauth_client_params(
tenant_id: str,
provider: str,
client_params: dict | None = None,
enable_oauth_custom_client: bool | None = None,
):
"""
setup oauth custom client
"""
if client_params is None and enable_oauth_custom_client is None:
return {"result": "success"}
tool_provider = ToolProviderID(provider)
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller:
raise ToolProviderNotFoundError(f"Provider {provider} not found")
if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
with Session(db.engine) as session:
custom_client_params = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
)
.first()
)
# if the record does not exist, create a basic record
if custom_client_params is None:
custom_client_params = ToolOAuthTenantClient(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
)
session.add(custom_client_params)
if client_params is not None:
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
original_params = encrypter.decrypt(custom_client_params.oauth_params)
new_params = {
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
for key, value in client_params.items()
}
custom_client_params.encrypted_oauth_params = json.dumps(encrypter.encrypt(new_params))
if enable_oauth_custom_client is not None:
custom_client_params.enabled = enable_oauth_custom_client
session.commit()
return {"result": "success"}
@staticmethod
def get_custom_oauth_client_params(tenant_id: str, provider: str):
"""
get custom oauth client params
"""
with Session(db.engine) as session:
tool_provider = ToolProviderID(provider)
custom_oauth_client_params: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
)
.first()
)
if custom_oauth_client_params is None:
return {}
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller:
raise ToolProviderNotFoundError(f"Provider {provider} not found")
if not isinstance(provider_controller, BuiltinToolProviderController):
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
return encrypter.mask_plugin_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))

View File

@@ -0,0 +1,734 @@
import hashlib
import json
import logging
from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Any
from urllib.parse import urlparse
from pydantic import BaseModel, Field
from sqlalchemy import or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.auth.auth_flow import auth
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError, MCPError
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.utils.encryption import ProviderConfigEncrypter
from models.tools import MCPToolProvider
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
# Constants
UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
CLIENT_NAME = "Dify"
EMPTY_TOOLS_JSON = "[]"
EMPTY_CREDENTIALS_JSON = "{}"
class OAuthDataType(StrEnum):
"""Types of OAuth data that can be saved."""
TOKENS = "tokens"
CLIENT_INFO = "client_info"
CODE_VERIFIER = "code_verifier"
MIXED = "mixed"
class ReconnectResult(BaseModel):
"""Result of reconnecting to an MCP provider"""
authed: bool = Field(description="Whether the provider is authenticated")
tools: str = Field(description="JSON string of tool list")
encrypted_credentials: str = Field(description="JSON string of encrypted credentials")
class ServerUrlValidationResult(BaseModel):
"""Result of server URL validation check"""
needs_validation: bool
validation_passed: bool = False
reconnect_result: ReconnectResult | None = None
encrypted_server_url: str | None = None
server_url_hash: str | None = None
@property
def should_update_server_url(self) -> bool:
"""Check if server URL should be updated based on validation result"""
return self.needs_validation and self.validation_passed and self.reconnect_result is not None
class MCPToolManageService:
"""Service class for managing MCP tools and providers."""
def __init__(self, session: Session):
self._session = session
# ========== Provider CRUD Operations ==========
def get_provider(
self, *, provider_id: str | None = None, server_identifier: str | None = None, tenant_id: str
) -> MCPToolProvider:
"""
Get MCP provider by ID or server identifier.
Args:
provider_id: Provider ID (UUID)
server_identifier: Server identifier
tenant_id: Tenant ID
Returns:
MCPToolProvider instance
Raises:
ValueError: If provider not found
"""
if server_identifier:
stmt = select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier
)
else:
stmt = select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id
)
provider = self._session.scalar(stmt)
if not provider:
raise ValueError("MCP tool not found")
return provider
def get_provider_entity(self, provider_id: str, tenant_id: str, by_server_id: bool = False) -> MCPProviderEntity:
"""Get provider entity by ID or server identifier."""
if by_server_id:
db_provider = self.get_provider(server_identifier=provider_id, tenant_id=tenant_id)
else:
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
return db_provider.to_entity()
def create_provider(
self,
*,
tenant_id: str,
name: str,
server_url: str,
user_id: str,
icon: str,
icon_type: str,
icon_background: str,
server_identifier: str,
configuration: MCPConfiguration,
authentication: MCPAuthentication | None = None,
headers: dict[str, str] | None = None,
) -> ToolProviderApiEntity:
"""Create a new MCP provider."""
# Validate URL format
if not self._is_valid_url(server_url):
raise ValueError("Server URL is not valid.")
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
# Check for existing provider
self._check_provider_exists(tenant_id, name, server_url_hash, server_identifier)
# Encrypt sensitive data
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None
encrypted_credentials = None
if authentication is not None and authentication.client_id:
encrypted_credentials = self._build_and_encrypt_credentials(
authentication.client_id, authentication.client_secret, tenant_id
)
# Create provider
mcp_tool = MCPToolProvider(
tenant_id=tenant_id,
name=name,
server_url=encrypted_server_url,
server_url_hash=server_url_hash,
user_id=user_id,
authed=False,
tools=EMPTY_TOOLS_JSON,
icon=self._prepare_icon(icon, icon_type, icon_background),
server_identifier=server_identifier,
timeout=configuration.timeout,
sse_read_timeout=configuration.sse_read_timeout,
encrypted_headers=encrypted_headers,
encrypted_credentials=encrypted_credentials,
)
self._session.add(mcp_tool)
self._session.flush()
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
return mcp_providers
def update_provider(
self,
*,
tenant_id: str,
provider_id: str,
name: str,
server_url: str,
icon: str,
icon_type: str,
icon_background: str,
server_identifier: str,
headers: dict[str, str] | None = None,
configuration: MCPConfiguration,
authentication: MCPAuthentication | None = None,
validation_result: ServerUrlValidationResult | None = None,
) -> None:
"""
Update an MCP provider.
Args:
validation_result: Pre-validation result from validate_server_url_change.
If provided and contains reconnect_result, it will be used
instead of performing network operations.
"""
mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
# Check for duplicate name (excluding current provider)
if name != mcp_provider.name:
stmt = select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id,
MCPToolProvider.name == name,
MCPToolProvider.id != provider_id,
)
existing_provider = self._session.scalar(stmt)
if existing_provider:
raise ValueError(f"MCP tool {name} already exists")
# Get URL update data from validation result
encrypted_server_url = None
server_url_hash = None
reconnect_result = None
if validation_result and validation_result.encrypted_server_url:
# Use all data from validation result
encrypted_server_url = validation_result.encrypted_server_url
server_url_hash = validation_result.server_url_hash
reconnect_result = validation_result.reconnect_result
try:
# Update basic fields
mcp_provider.updated_at = datetime.now()
mcp_provider.name = name
mcp_provider.icon = self._prepare_icon(icon, icon_type, icon_background)
mcp_provider.server_identifier = server_identifier
# Update server URL if changed
if encrypted_server_url and server_url_hash:
mcp_provider.server_url = encrypted_server_url
mcp_provider.server_url_hash = server_url_hash
if reconnect_result:
mcp_provider.authed = reconnect_result.authed
mcp_provider.tools = reconnect_result.tools
mcp_provider.encrypted_credentials = reconnect_result.encrypted_credentials
# Update optional configuration fields
self._update_optional_fields(mcp_provider, configuration)
# Update headers if provided
if headers is not None:
mcp_provider.encrypted_headers = self._process_headers(headers, mcp_provider, tenant_id)
# Update credentials if provided
if authentication and authentication.client_id:
mcp_provider.encrypted_credentials = self._process_credentials(authentication, mcp_provider, tenant_id)
# Flush changes to database
self._session.flush()
except IntegrityError as e:
self._handle_integrity_error(e, name, server_url, server_identifier)
def delete_provider(self, *, tenant_id: str, provider_id: str) -> None:
"""Delete an MCP provider."""
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
self._session.delete(mcp_tool)
def list_providers(
self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True
) -> list[ToolProviderApiEntity]:
"""List all MCP providers for a tenant.
Args:
tenant_id: Tenant ID
for_list: If True, return provider ID; if False, return server identifier
include_sensitive: If False, skip expensive decryption operations (default: True for backward compatibility)
"""
from models.account import Account
stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name)
mcp_providers = self._session.scalars(stmt).all()
if not mcp_providers:
return []
# Batch query all users to avoid N+1 problem
user_ids = {provider.user_id for provider in mcp_providers}
users = self._session.query(Account).where(Account.id.in_(user_ids)).all()
user_name_map = {user.id: user.name for user in users}
return [
ToolTransformService.mcp_provider_to_user_provider(
provider,
for_list=for_list,
user_name=user_name_map.get(provider.user_id),
include_sensitive=include_sensitive,
)
for provider in mcp_providers
]
# ========== Tool Operations ==========
def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
"""List tools from remote MCP server."""
# Load provider and convert to entity
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
provider_entity = db_provider.to_entity()
# Verify authentication
if not provider_entity.authed:
raise ValueError("Please auth the tool first")
# Prepare headers with auth token
headers = self._prepare_auth_headers(provider_entity)
# Retrieve tools from remote server
server_url = provider_entity.decrypt_server_url()
try:
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
except MCPError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
# Update database with retrieved tools
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
db_provider.authed = True
db_provider.updated_at = datetime.now()
self._session.flush()
# Build API response
return self._build_tool_provider_response(db_provider, provider_entity, tools)
# ========== OAuth and Credentials Operations ==========
def update_provider_credentials(
self, *, provider_id: str, tenant_id: str, credentials: dict[str, Any], authed: bool | None = None
) -> None:
"""
Update provider credentials with encryption.
Args:
provider_id: Provider ID
tenant_id: Tenant ID
credentials: Credentials to save
authed: Whether provider is authenticated (None means keep current state)
"""
from core.tools.mcp_tool.provider import MCPToolProviderController
# Get provider from current session
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
# Encrypt new credentials
provider_controller = MCPToolProviderController.from_db(provider)
tool_configuration = ProviderConfigEncrypter(
tenant_id=provider.tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_config_cache=NoOpProviderCredentialCache(),
)
encrypted_credentials = tool_configuration.encrypt(credentials)
# Update provider
provider.updated_at = datetime.now()
provider.encrypted_credentials = json.dumps({**provider.credentials, **encrypted_credentials})
if authed is not None:
provider.authed = authed
if not authed:
provider.tools = EMPTY_TOOLS_JSON
# Flush changes to database
self._session.flush()
def save_oauth_data(
self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: OAuthDataType = OAuthDataType.MIXED
) -> None:
"""
Save OAuth-related data (tokens, client info, code verifier).
Args:
provider_id: Provider ID
tenant_id: Tenant ID
data: Data to save (tokens, client info, or code verifier)
data_type: Type of OAuth data to save
"""
# Determine if this makes the provider authenticated
authed = (
data_type == OAuthDataType.TOKENS or (data_type == OAuthDataType.MIXED and "access_token" in data) or None
)
# update_provider_credentials will validate provider existence
self.update_provider_credentials(provider_id=provider_id, tenant_id=tenant_id, credentials=data, authed=authed)
def clear_provider_credentials(self, *, provider_id: str, tenant_id: str) -> None:
"""
Clear all credentials for a provider.
Args:
provider_id: Provider ID
tenant_id: Tenant ID
"""
# Get provider from current session
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
provider.tools = EMPTY_TOOLS_JSON
provider.encrypted_credentials = EMPTY_CREDENTIALS_JSON
provider.updated_at = datetime.now()
provider.authed = False
# ========== Private Helper Methods ==========
def _check_provider_exists(self, tenant_id: str, name: str, server_url_hash: str, server_identifier: str) -> None:
"""Check if provider with same attributes already exists."""
stmt = select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id,
or_(
MCPToolProvider.name == name,
MCPToolProvider.server_url_hash == server_url_hash,
MCPToolProvider.server_identifier == server_identifier,
),
)
existing_provider = self._session.scalar(stmt)
if existing_provider:
if existing_provider.name == name:
raise ValueError(f"MCP tool {name} already exists")
if existing_provider.server_url_hash == server_url_hash:
raise ValueError("MCP tool with this server URL already exists")
if existing_provider.server_identifier == server_identifier:
raise ValueError(f"MCP tool {server_identifier} already exists")
def _prepare_icon(self, icon: str, icon_type: str, icon_background: str) -> str:
"""Prepare icon data for storage."""
if icon_type == "emoji":
return json.dumps({"content": icon, "background": icon_background})
return icon
def _encrypt_dict_fields(self, data: dict[str, Any], secret_fields: list[str], tenant_id: str) -> Mapping[str, str]:
"""Encrypt specified fields in a dictionary.
Args:
data: Dictionary containing data to encrypt
secret_fields: List of field names to encrypt
tenant_id: Tenant ID for encryption
Returns:
JSON string of encrypted data
"""
from core.entities.provider_entities import BasicProviderConfig
from core.tools.utils.encryption import create_provider_encrypter
# Create config for secret fields
config = [
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=field) for field in secret_fields
]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
encrypted_data = encrypter_instance.encrypt(data)
return encrypted_data
def _prepare_encrypted_dict(self, headers: dict[str, str], tenant_id: str) -> str:
"""Encrypt headers and prepare for storage."""
# All headers are treated as secret
return json.dumps(self._encrypt_dict_fields(headers, list(headers.keys()), tenant_id))
def _prepare_auth_headers(self, provider_entity: MCPProviderEntity) -> dict[str, str]:
"""Prepare headers with OAuth token if available."""
headers = provider_entity.decrypt_headers()
tokens = provider_entity.retrieve_tokens()
if tokens:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
return headers
def _retrieve_remote_mcp_tools(
self,
server_url: str,
headers: dict[str, str],
provider_entity: MCPProviderEntity,
):
"""Retrieve tools from remote MCP server."""
with MCPClientWithAuthRetry(
server_url=server_url,
headers=headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
provider_entity=provider_entity,
) as mcp_client:
return mcp_client.list_tools()
def execute_auth_actions(self, auth_result: Any) -> dict[str, str]:
"""
Execute the actions returned by the auth function.
This method processes the AuthResult and performs the necessary database operations.
Args:
auth_result: The result from the auth function
Returns:
The response from the auth result
"""
from core.mcp.entities import AuthAction, AuthActionType
action: AuthAction
for action in auth_result.actions:
if action.provider_id is None or action.tenant_id is None:
continue
if action.action_type == AuthActionType.SAVE_CLIENT_INFO:
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CLIENT_INFO)
elif action.action_type == AuthActionType.SAVE_TOKENS:
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.TOKENS)
elif action.action_type == AuthActionType.SAVE_CODE_VERIFIER:
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CODE_VERIFIER)
return auth_result.response
def auth_with_actions(
self,
provider_entity: MCPProviderEntity,
authorization_code: str | None = None,
resource_metadata_url: str | None = None,
scope_hint: str | None = None,
) -> dict[str, str]:
"""
Perform authentication and execute all resulting actions.
This method is used by MCPClientWithAuthRetry for automatic re-authentication.
Args:
provider_entity: The MCP provider entity
authorization_code: Optional authorization code
resource_metadata_url: Optional Protected Resource Metadata URL from WWW-Authenticate
scope_hint: Optional scope hint from WWW-Authenticate header
Returns:
Response dictionary from auth result
"""
auth_result = auth(
provider_entity,
authorization_code,
resource_metadata_url=resource_metadata_url,
scope_hint=scope_hint,
)
return self.execute_auth_actions(auth_result)
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult:
"""Attempt to reconnect to MCP provider with new server URL."""
provider_entity = provider.to_entity()
headers = provider_entity.headers
try:
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
return ReconnectResult(
authed=True,
tools=json.dumps([tool.model_dump() for tool in tools]),
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
)
except MCPAuthError:
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
except MCPError as e:
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
def validate_server_url_change(
self, *, tenant_id: str, provider_id: str, new_server_url: str
) -> ServerUrlValidationResult:
"""
Validate server URL change by attempting to connect to the new server.
This method should be called BEFORE update_provider to perform network operations
outside of the database transaction.
Returns:
ServerUrlValidationResult: Validation result with connection status and tools if successful
"""
# Handle hidden/unchanged URL
if UNCHANGED_SERVER_URL_PLACEHOLDER in new_server_url:
return ServerUrlValidationResult(needs_validation=False)
# Validate URL format
if not self._is_valid_url(new_server_url):
raise ValueError("Server URL is not valid.")
# Always encrypt and hash the URL
encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url)
new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest()
# Get current provider
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
# Check if URL is actually different
if new_server_url_hash == provider.server_url_hash:
# URL hasn't changed, but still return the encrypted data
return ServerUrlValidationResult(
needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash
)
# Perform validation by attempting to connect
reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider)
return ServerUrlValidationResult(
needs_validation=True,
validation_passed=True,
reconnect_result=reconnect_result,
encrypted_server_url=encrypted_server_url,
server_url_hash=new_server_url_hash,
)
def _build_tool_provider_response(
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
) -> ToolProviderApiEntity:
"""Build API response for tool provider."""
user = db_provider.load_user()
response = provider_entity.to_api_response(
user_name=user.name if user else None,
)
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, tools)
response["plugin_unique_identifier"] = provider_entity.provider_id
return ToolProviderApiEntity(**response)
def _handle_integrity_error(
self, error: IntegrityError, name: str, server_url: str, server_identifier: str
) -> None:
"""Handle database integrity errors with user-friendly messages."""
error_msg = str(error.orig)
if "unique_mcp_provider_name" in error_msg:
raise ValueError(f"MCP tool {name} already exists")
if "unique_mcp_provider_server_url" in error_msg:
raise ValueError(f"MCP tool {server_url} already exists")
if "unique_mcp_provider_server_identifier" in error_msg:
raise ValueError(f"MCP tool {server_identifier} already exists")
raise
def _is_valid_url(self, url: str) -> bool:
"""Validate URL format."""
if not url:
return False
try:
parsed = urlparse(url)
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
except (ValueError, TypeError):
return False
def _update_optional_fields(self, mcp_provider: MCPToolProvider, configuration: MCPConfiguration) -> None:
"""Update optional configuration fields using setattr for cleaner code."""
field_mapping = {"timeout": configuration.timeout, "sse_read_timeout": configuration.sse_read_timeout}
for field, value in field_mapping.items():
if value is not None:
setattr(mcp_provider, field, value)
def _process_headers(self, headers: dict[str, str], mcp_provider: MCPToolProvider, tenant_id: str) -> str | None:
"""Process headers update, handling empty dict to clear headers."""
if not headers:
return None
# Merge with existing headers to preserve masked values
final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider)
return self._prepare_encrypted_dict(final_headers, tenant_id)
def _process_credentials(
self, authentication: MCPAuthentication, mcp_provider: MCPToolProvider, tenant_id: str
) -> str:
"""Process credentials update, handling masked values."""
# Merge with existing credentials
final_client_id, final_client_secret = self._merge_credentials_with_masked(
authentication.client_id, authentication.client_secret, mcp_provider
)
# Build and encrypt
return self._build_and_encrypt_credentials(final_client_id, final_client_secret, tenant_id)
def _merge_headers_with_masked(
self, incoming_headers: dict[str, str], mcp_provider: MCPToolProvider
) -> dict[str, str]:
"""Merge incoming headers with existing ones, preserving unchanged masked values.
Args:
incoming_headers: Headers from frontend (may contain masked values)
mcp_provider: The MCP provider instance
Returns:
Final headers dict with proper values (original for unchanged masked, new for changed)
"""
mcp_provider_entity = mcp_provider.to_entity()
existing_decrypted = mcp_provider_entity.decrypt_headers()
existing_masked = mcp_provider_entity.masked_headers()
return {
key: (str(existing_decrypted[key]) if key in existing_masked and value == existing_masked[key] else value)
for key, value in incoming_headers.items()
if key in existing_decrypted or value != existing_masked.get(key)
}
def _merge_credentials_with_masked(
self,
client_id: str,
client_secret: str | None,
mcp_provider: MCPToolProvider,
) -> tuple[
str,
str | None,
]:
"""Merge incoming credentials with existing ones, preserving unchanged masked values.
Args:
client_id: Client ID from frontend (may be masked)
client_secret: Client secret from frontend (may be masked)
mcp_provider: The MCP provider instance
Returns:
Tuple of (final_client_id, final_client_secret)
"""
mcp_provider_entity = mcp_provider.to_entity()
existing_decrypted = mcp_provider_entity.decrypt_credentials()
existing_masked = mcp_provider_entity.masked_credentials()
# Check if client_id is masked and unchanged
final_client_id = client_id
if existing_masked.get("client_id") and client_id == existing_masked["client_id"]:
# Use existing decrypted value
final_client_id = existing_decrypted.get("client_id", client_id)
# Check if client_secret is masked and unchanged
final_client_secret = client_secret
if existing_masked.get("client_secret") and client_secret == existing_masked["client_secret"]:
# Use existing decrypted value
final_client_secret = existing_decrypted.get("client_secret", client_secret)
return final_client_id, final_client_secret
def _build_and_encrypt_credentials(self, client_id: str, client_secret: str | None, tenant_id: str) -> str:
"""Build credentials and encrypt sensitive fields."""
# Create a flat structure with all credential data
credentials_data = {
"client_id": client_id,
"client_name": CLIENT_NAME,
"is_dynamic_registration": False,
}
secret_fields = []
if client_secret is not None:
credentials_data["encrypted_client_secret"] = client_secret
secret_fields = ["encrypted_client_secret"]
client_info = self._encrypt_dict_fields(credentials_data, secret_fields, tenant_id)
return json.dumps({"client_information": client_info})

View File

@@ -0,0 +1,8 @@
from core.tools.entities.tool_entities import ToolLabel
from core.tools.entities.values import default_tool_labels
class ToolLabelsService:
@classmethod
def list_tool_labels(cls) -> list[ToolLabel]:
return default_tool_labels

View File

@@ -0,0 +1,26 @@
import logging
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from core.tools.tool_manager import ToolManager
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
class ToolCommonService:
@staticmethod
def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral | None = None):
"""
list tool providers
:return: the list of tool providers
"""
providers = ToolManager.list_providers_from_api(user_id, tenant_id, typ)
# add icon
for provider in providers:
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
result = [provider.to_dict() for provider in providers]
return result

View File

@@ -0,0 +1,475 @@
import json
import logging
from collections.abc import Mapping
from typing import Any, Union
from pydantic import ValidationError
from yarl import URL
from configs import dify_config
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.mcp.types import Tool as MCPTool
from core.plugin.entities.plugin_daemon import CredentialType, PluginDatasourceProviderEntity
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolParameter,
ToolProviderType,
)
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__)
class ToolTransformService:
@classmethod
def get_tool_provider_icon_url(
cls, provider_type: str, provider_name: str, icon: str | Mapping[str, str]
) -> str | Mapping[str, str]:
"""
get tool provider icon url
"""
url_prefix = (
URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "tool-provider"
)
if provider_type == ToolProviderType.BUILT_IN:
return str(url_prefix / "builtin" / provider_name / "icon")
elif provider_type in {ToolProviderType.API, ToolProviderType.WORKFLOW}:
try:
if isinstance(icon, str):
return json.loads(icon)
return icon
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
elif provider_type == ToolProviderType.MCP:
return icon
return ""
@staticmethod
def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]):
"""
repack provider
:param tenant_id: the tenant id
:param provider: the provider dict
"""
if isinstance(provider, dict) and "icon" in provider:
provider["icon"] = ToolTransformService.get_tool_provider_icon_url(
provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
)
elif isinstance(provider, ToolProviderApiEntity):
if provider.plugin_id:
if isinstance(provider.icon, str):
provider.icon = PluginService.get_plugin_icon_url(tenant_id=tenant_id, filename=provider.icon)
if isinstance(provider.icon_dark, str) and provider.icon_dark:
provider.icon_dark = PluginService.get_plugin_icon_url(
tenant_id=tenant_id, filename=provider.icon_dark
)
else:
provider.icon = ToolTransformService.get_tool_provider_icon_url(
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
)
if provider.icon_dark:
provider.icon_dark = ToolTransformService.get_tool_provider_icon_url(
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon_dark
)
elif isinstance(provider, PluginDatasourceProviderEntity):
if provider.plugin_id:
if isinstance(provider.declaration.identity.icon, str):
provider.declaration.identity.icon = PluginService.get_plugin_icon_url(
tenant_id=tenant_id, filename=provider.declaration.identity.icon
)
@classmethod
def builtin_provider_to_user_provider(
cls,
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
db_provider: BuiltinToolProvider | None,
decrypt_credentials: bool = True,
) -> ToolProviderApiEntity:
"""
convert provider controller to user provider
"""
result = ToolProviderApiEntity(
id=provider_controller.entity.identity.name,
author=provider_controller.entity.identity.author,
name=provider_controller.entity.identity.name,
description=provider_controller.entity.identity.description,
icon=provider_controller.entity.identity.icon,
icon_dark=provider_controller.entity.identity.icon_dark or "",
label=provider_controller.entity.identity.label,
type=ToolProviderType.BUILT_IN,
masked_credentials={},
is_team_authorization=False,
plugin_id=None,
tools=[],
labels=provider_controller.tool_labels,
)
if isinstance(provider_controller, PluginToolProviderController):
result.plugin_id = provider_controller.plugin_id
result.plugin_unique_identifier = provider_controller.plugin_unique_identifier
# get credentials schema
schema = {
x.to_basic_provider_config().name: x
for x in provider_controller.get_credentials_schema_by_type(
CredentialType.of(db_provider.credential_type) if db_provider else CredentialType.API_KEY
)
}
masked_creds = {}
for name in schema:
masked_creds[name] = ""
result.masked_credentials = masked_creds
# check if the provider need credentials
if not provider_controller.need_credentials:
result.is_team_authorization = True
result.allow_delete = False
elif db_provider:
result.is_team_authorization = True
if decrypt_credentials:
credentials = db_provider.credentials
if not db_provider.tenant_id:
raise ValueError(f"Required tenant_id is missing for BuiltinToolProvider with id {db_provider.id}")
# init tool configuration
encrypter, _ = create_provider_encrypter(
tenant_id=db_provider.tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(
CredentialType.of(db_provider.credential_type)
)
],
cache=ToolProviderCredentialsCache(
tenant_id=db_provider.tenant_id,
provider=db_provider.provider,
credential_id=db_provider.id,
),
)
# decrypt the credentials and mask the credentials
decrypted_credentials = encrypter.decrypt(data=credentials)
masked_credentials = encrypter.mask_plugin_credentials(data=decrypted_credentials)
result.masked_credentials = masked_credentials
result.original_credentials = decrypted_credentials
return result
@staticmethod
def api_provider_to_controller(
db_provider: ApiToolProvider,
) -> ApiToolProviderController:
"""
convert provider controller to user provider
"""
# package tool provider controller
auth_type = ApiProviderAuthType.NONE
credentials_auth_type = db_provider.credentials.get("auth_type")
if credentials_auth_type in ("api_key_header", "api_key"): # backward compatibility
auth_type = ApiProviderAuthType.API_KEY_HEADER
elif credentials_auth_type == "api_key_query":
auth_type = ApiProviderAuthType.API_KEY_QUERY
controller = ApiToolProviderController.from_db(
db_provider=db_provider,
auth_type=auth_type,
)
return controller
@staticmethod
def workflow_provider_to_controller(db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
"""
convert provider controller to provider
"""
return WorkflowToolProviderController.from_db(db_provider)
@staticmethod
def workflow_provider_to_user_provider(
provider_controller: WorkflowToolProviderController, labels: list[str] | None = None
):
"""
convert provider controller to user provider
"""
return ToolProviderApiEntity(
id=provider_controller.provider_id,
author=provider_controller.entity.identity.author,
name=provider_controller.entity.identity.name,
description=provider_controller.entity.identity.description,
icon=provider_controller.entity.identity.icon,
icon_dark=provider_controller.entity.identity.icon_dark or "",
label=provider_controller.entity.identity.label,
type=ToolProviderType.WORKFLOW,
masked_credentials={},
is_team_authorization=True,
plugin_id=None,
plugin_unique_identifier=None,
tools=[],
labels=labels or [],
)
@staticmethod
def mcp_provider_to_user_provider(
db_provider: MCPToolProvider,
for_list: bool = False,
user_name: str | None = None,
include_sensitive: bool = True,
) -> ToolProviderApiEntity:
from core.entities.mcp_provider import MCPConfiguration
# Use provided user_name to avoid N+1 query, fallback to load_user() if not provided
if user_name is None:
user = db_provider.load_user()
user_name = user.name if user else None
# Convert to entity and use its API response method
provider_entity = db_provider.to_entity()
response = provider_entity.to_api_response(user_name=user_name, include_sensitive=include_sensitive)
try:
mcp_tools = [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
except (ValidationError, json.JSONDecodeError):
mcp_tools = []
# Add additional fields specific to the transform
response["id"] = db_provider.server_identifier if not for_list else db_provider.id
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, mcp_tools, user_name=user_name)
response["server_identifier"] = db_provider.server_identifier
# Convert configuration dict to MCPConfiguration object
if "configuration" in response and isinstance(response["configuration"], dict):
response["configuration"] = MCPConfiguration(
timeout=float(response["configuration"]["timeout"]),
sse_read_timeout=float(response["configuration"]["sse_read_timeout"]),
)
return ToolProviderApiEntity(**response)
@staticmethod
def mcp_tool_to_user_tool(
mcp_provider: MCPToolProvider, tools: list[MCPTool], user_name: str | None = None
) -> list[ToolApiEntity]:
# Use provided user_name to avoid N+1 query, fallback to load_user() if not provided
if user_name is None:
user = mcp_provider.load_user()
user_name = user.name if user else "Anonymous"
return [
ToolApiEntity(
author=user_name or "Anonymous",
name=tool.name,
label=I18nObject(en_US=tool.name, zh_Hans=tool.name),
description=I18nObject(en_US=tool.description or "", zh_Hans=tool.description or ""),
parameters=ToolTransformService.convert_mcp_schema_to_parameter(tool.inputSchema),
labels=[],
output_schema=tool.outputSchema or {},
)
for tool in tools
]
@classmethod
def api_provider_to_user_provider(
cls,
provider_controller: ApiToolProviderController,
db_provider: ApiToolProvider,
decrypt_credentials: bool = True,
labels: list[str] | None = None,
) -> ToolProviderApiEntity:
"""
convert provider controller to user provider
"""
username = "Anonymous"
if db_provider.user is None:
raise ValueError(f"user is None for api provider {db_provider.id}")
try:
user = db_provider.user
if not user:
raise ValueError("user not found")
username = user.name
except Exception:
logger.exception("failed to get user name for api provider %s", db_provider.id)
# add provider into providers
credentials = db_provider.credentials
result = ToolProviderApiEntity(
id=db_provider.id,
author=username,
name=db_provider.name,
description=I18nObject(
en_US=db_provider.description,
zh_Hans=db_provider.description,
),
icon=db_provider.icon,
label=I18nObject(
en_US=db_provider.name,
zh_Hans=db_provider.name,
),
type=ToolProviderType.API,
plugin_id=None,
plugin_unique_identifier=None,
masked_credentials={},
is_team_authorization=True,
tools=[],
labels=labels or [],
)
if decrypt_credentials:
# init tool configuration
encrypter, _ = create_tool_provider_encrypter(
tenant_id=db_provider.tenant_id,
controller=provider_controller,
)
# decrypt the credentials and mask the credentials
decrypted_credentials = encrypter.decrypt(data=credentials)
masked_credentials = encrypter.mask_plugin_credentials(data=decrypted_credentials)
result.masked_credentials = masked_credentials
return result
@staticmethod
def convert_tool_entity_to_api_entity(
tool: ApiToolBundle | WorkflowTool | Tool,
tenant_id: str,
labels: list[str] | None = None,
) -> ToolApiEntity:
"""
convert tool to user tool
"""
if isinstance(tool, Tool):
# fork tool runtime
tool = tool.fork_tool_runtime(
runtime=ToolRuntime(
credentials={},
tenant_id=tenant_id,
)
)
# get tool parameters
base_parameters = tool.entity.parameters or []
# get tool runtime parameters
runtime_parameters = tool.get_runtime_parameters()
# merge parameters using a functional approach to avoid type issues
merged_parameters: list[ToolParameter] = []
# create a mapping of runtime parameters for quick lookup
runtime_param_map = {(rp.name, rp.form): rp for rp in runtime_parameters}
# process base parameters, replacing with runtime versions if they exist
for base_param in base_parameters:
key = (base_param.name, base_param.form)
if key in runtime_param_map:
merged_parameters.append(runtime_param_map[key])
else:
merged_parameters.append(base_param)
# add any runtime parameters that weren't in base parameters
for runtime_parameter in runtime_parameters:
if runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
# check if this parameter is already in merged_parameters
already_exists = any(
p.name == runtime_parameter.name and p.form == runtime_parameter.form for p in merged_parameters
)
if not already_exists:
merged_parameters.append(runtime_parameter)
return ToolApiEntity(
author=tool.entity.identity.author,
name=tool.entity.identity.name,
label=tool.entity.identity.label,
description=tool.entity.description.human if tool.entity.description else I18nObject(en_US=""),
output_schema=tool.entity.output_schema,
parameters=merged_parameters,
labels=labels or [],
)
else:
assert tool.operation_id
return ToolApiEntity(
author=tool.author,
name=tool.operation_id or "",
label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
parameters=tool.parameters,
labels=labels or [],
)
@staticmethod
def convert_builtin_provider_to_credential_entity(
provider: BuiltinToolProvider, credentials: dict
) -> ToolProviderCredentialApiEntity:
return ToolProviderCredentialApiEntity(
id=provider.id,
name=provider.name,
provider=provider.provider,
credential_type=CredentialType.of(provider.credential_type),
is_default=provider.is_default,
credentials=credentials,
)
@staticmethod
def convert_mcp_schema_to_parameter(schema: dict[str, Any]) -> list["ToolParameter"]:
"""
Convert MCP JSON schema to tool parameters
:param schema: JSON schema dictionary
:return: list of ToolParameter instances
"""
def create_parameter(
name: str, description: str, param_type: str, required: bool, input_schema: dict[str, Any] | None = None
) -> ToolParameter:
"""Create a ToolParameter instance with given attributes"""
input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {}
return ToolParameter(
name=name,
llm_description=description,
label=I18nObject(en_US=name),
form=ToolParameter.ToolParameterForm.LLM,
required=required,
type=ToolParameter.ToolParameterType(param_type),
human_description=I18nObject(en_US=description),
**input_schema_dict,
)
def process_properties(
props: dict[str, dict[str, Any]], required: list[str], prefix: str = ""
) -> list[ToolParameter]:
"""Process properties recursively"""
TYPE_MAPPING = {"integer": "number", "float": "number"}
COMPLEX_TYPES = ["array", "object"]
parameters = []
for name, prop in props.items():
current_description = prop.get("description", "")
prop_type = prop.get("type", "string")
if isinstance(prop_type, list):
prop_type = prop_type[0]
if prop_type in TYPE_MAPPING:
prop_type = TYPE_MAPPING[prop_type]
input_schema = prop if prop_type in COMPLEX_TYPES else None
parameters.append(
create_parameter(name, current_description, prop_type, name in required, input_schema)
)
return parameters
if schema.get("type") == "object" and "properties" in schema:
return process_properties(schema["properties"], schema.get("required", []))
return []

View File

@@ -0,0 +1,340 @@
import json
from collections.abc import Mapping
from datetime import datetime
from typing import Any
from sqlalchemy import or_, select
from sqlalchemy.orm import Session
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.model import App
from models.tools import WorkflowToolProvider
from models.workflow import Workflow
from services.tools.tools_transform_service import ToolTransformService
class WorkflowToolManageService:
"""
Service class for managing workflow tools.
"""
@staticmethod
def create_workflow_tool(
*,
user_id: str,
tenant_id: str,
workflow_app_id: str,
name: str,
label: str,
icon: dict,
description: str,
parameters: list[Mapping[str, Any]],
privacy_policy: str = "",
labels: list[str] | None = None,
):
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
# name or app_id
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
)
.first()
)
if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")
app: App | None = db.session.query(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
if app is None:
raise ValueError(f"App {workflow_app_id} not found")
workflow: Workflow | None = app.workflow
if workflow is None:
raise ValueError(f"Workflow not found for app {workflow_app_id}")
with Session(db.engine, expire_on_commit=False) as session, session.begin():
workflow_tool_provider = WorkflowToolProvider(
tenant_id=tenant_id,
user_id=user_id,
app_id=workflow_app_id,
name=name,
label=label,
icon=json.dumps(icon),
description=description,
parameter_configuration=json.dumps(parameters),
privacy_policy=privacy_policy,
version=workflow.version,
)
session.add(workflow_tool_provider)
try:
WorkflowToolProviderController.from_db(workflow_tool_provider)
except Exception as e:
raise ValueError(str(e))
if labels is not None:
ToolLabelManager.update_tool_labels(
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
)
return {"result": "success"}
@classmethod
def update_workflow_tool(
cls,
user_id: str,
tenant_id: str,
workflow_tool_id: str,
name: str,
label: str,
icon: dict,
description: str,
parameters: list[Mapping[str, Any]],
privacy_policy: str = "",
labels: list[str] | None = None,
):
"""
Update a workflow tool.
:param user_id: the user id
:param tenant_id: the tenant id
:param workflow_tool_id: workflow tool id
:param name: name
:param label: label
:param icon: icon
:param description: description
:param parameters: parameters
:param privacy_policy: privacy policy
:param labels: labels
:return: the updated tool
"""
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.name == name,
WorkflowToolProvider.id != workflow_tool_id,
)
.first()
)
if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} already exists")
workflow_tool_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
)
if workflow_tool_provider is None:
raise ValueError(f"Tool {workflow_tool_id} not found")
app: App | None = (
db.session.query(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
)
if app is None:
raise ValueError(f"App {workflow_tool_provider.app_id} not found")
workflow: Workflow | None = app.workflow
if workflow is None:
raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")
workflow_tool_provider.name = name
workflow_tool_provider.label = label
workflow_tool_provider.icon = json.dumps(icon)
workflow_tool_provider.description = description
workflow_tool_provider.parameter_configuration = json.dumps(parameters)
workflow_tool_provider.privacy_policy = privacy_policy
workflow_tool_provider.version = workflow.version
workflow_tool_provider.updated_at = datetime.now()
try:
WorkflowToolProviderController.from_db(workflow_tool_provider)
except Exception as e:
raise ValueError(str(e))
db.session.commit()
if labels is not None:
ToolLabelManager.update_tool_labels(
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
)
return {"result": "success"}
@classmethod
def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
"""
List workflow tools.
:param user_id: the user id
:param tenant_id: the tenant id
:return: the list of tools
"""
db_tools = db.session.scalars(
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
).all()
tools: list[WorkflowToolProviderController] = []
for provider in db_tools:
try:
tools.append(ToolTransformService.workflow_provider_to_controller(provider))
except Exception:
# skip deleted tools
pass
labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)])
result = []
for tool in tools:
user_tool_provider = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=tool, labels=labels.get(tool.provider_id, [])
)
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_tool_provider)
user_tool_provider.tools = [
ToolTransformService.convert_tool_entity_to_api_entity(
tool=tool.get_tools(tenant_id)[0],
labels=labels.get(tool.provider_id, []),
tenant_id=tenant_id,
)
]
result.append(user_tool_provider)
return result
@classmethod
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
"""
Delete a workflow tool.
:param user_id: the user id
:param tenant_id: the tenant id
:param workflow_tool_id: the workflow tool id
"""
db.session.query(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
).delete()
db.session.commit()
return {"result": "success"}
@classmethod
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
"""
Get a workflow tool.
:param user_id: the user id
:param tenant_id: the tenant id
:param workflow_tool_id: the workflow tool id
:return: the tool
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
)
return cls._get_workflow_tool(tenant_id, db_tool)
@classmethod
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str):
"""
Get a workflow tool.
:param user_id: the user id
:param tenant_id: the tenant id
:param workflow_app_id: the workflow app id
:return: the tool
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.first()
)
return cls._get_workflow_tool(tenant_id, db_tool)
@classmethod
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None):
"""
Get a workflow tool.
:db_tool: the database tool
:return: the tool
"""
if db_tool is None:
raise ValueError("Tool not found")
workflow_app: App | None = (
db.session.query(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first()
)
if workflow_app is None:
raise ValueError(f"App {db_tool.app_id} not found")
workflow = workflow_app.workflow
if not workflow:
raise ValueError("Workflow not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
if len(workflow_tools) == 0:
raise ValueError(f"Tool {db_tool.id} not found")
return {
"name": db_tool.name,
"label": db_tool.label,
"workflow_tool_id": db_tool.id,
"workflow_app_id": db_tool.app_id,
"icon": json.loads(db_tool.icon),
"description": db_tool.description,
"parameters": jsonable_encoder(db_tool.parameter_configurations),
"tool": ToolTransformService.convert_tool_entity_to_api_entity(
tool=tool.get_tools(db_tool.tenant_id)[0],
labels=ToolLabelManager.get_tool_labels(tool),
tenant_id=tenant_id,
),
"synced": workflow.version == db_tool.version,
"privacy_policy": db_tool.privacy_policy,
}
@classmethod
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[ToolApiEntity]:
"""
List workflow tool provider tools.
:param user_id: the user id
:param tenant_id: the tenant id
:param workflow_tool_id: the workflow tool id
:return: the list of tools
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
)
if db_tool is None:
raise ValueError(f"Tool {workflow_tool_id} not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
if len(workflow_tools) == 0:
raise ValueError(f"Tool {workflow_tool_id} not found")
return [
ToolTransformService.convert_tool_entity_to_api_entity(
tool=tool.get_tools(db_tool.tenant_id)[0],
labels=ToolLabelManager.get_tool_labels(tool),
tenant_id=tenant_id,
)
]