dify
This commit is contained in:
474
dify/api/services/tools/api_tools_manage_service.py
Normal file
474
dify/api/services/tools/api_tools_manage_service.py
Normal file
@@ -0,0 +1,474 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from httpx import get
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ApiProviderSchemaType,
|
||||
)
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.encryption import create_tool_provider_encrypter
|
||||
from core.tools.utils.parser import ApiBasedToolSchemaParser
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ApiToolManageService:
|
||||
@staticmethod
|
||||
def parser_api_schema(schema: str) -> Mapping[str, Any]:
|
||||
"""
|
||||
parse api schema to tool bundle
|
||||
"""
|
||||
try:
|
||||
warnings: dict[str, str] = {}
|
||||
try:
|
||||
tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
|
||||
except Exception as e:
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
|
||||
credentials_schema = [
|
||||
ProviderConfig(
|
||||
name="auth_type",
|
||||
type=ProviderConfig.Type.SELECT,
|
||||
required=True,
|
||||
default="none",
|
||||
options=[
|
||||
ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
|
||||
ProviderConfig.Option(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
|
||||
],
|
||||
placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
|
||||
),
|
||||
ProviderConfig(
|
||||
name="api_key_header",
|
||||
type=ProviderConfig.Type.TEXT_INPUT,
|
||||
required=False,
|
||||
placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"),
|
||||
default="api_key",
|
||||
help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"),
|
||||
),
|
||||
ProviderConfig(
|
||||
name="api_key_value",
|
||||
type=ProviderConfig.Type.TEXT_INPUT,
|
||||
required=False,
|
||||
placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"),
|
||||
default="",
|
||||
),
|
||||
]
|
||||
|
||||
return cast(
|
||||
Mapping,
|
||||
jsonable_encoder(
|
||||
{
|
||||
"schema_type": schema_type,
|
||||
"parameters_schema": tool_bundles,
|
||||
"credentials_schema": credentials_schema,
|
||||
"warning": warnings,
|
||||
}
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]:
|
||||
"""
|
||||
convert schema to tool bundles
|
||||
|
||||
:return: the list of tool bundles, description
|
||||
"""
|
||||
try:
|
||||
return ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
|
||||
except Exception as e:
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def create_api_tool_provider(
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
provider_name: str,
|
||||
icon: dict,
|
||||
credentials: dict,
|
||||
schema_type: str,
|
||||
schema: str,
|
||||
privacy_policy: str,
|
||||
custom_disclaimer: str,
|
||||
labels: list[str],
|
||||
):
|
||||
"""
|
||||
create api tool provider
|
||||
"""
|
||||
if schema_type not in [member.value for member in ApiProviderSchemaType]:
|
||||
raise ValueError(f"invalid schema type {schema}")
|
||||
|
||||
provider_name = provider_name.strip()
|
||||
|
||||
# check if the provider exists
|
||||
provider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is not None:
|
||||
raise ValueError(f"provider {provider_name} already exists")
|
||||
|
||||
# parse openapi to tool bundle
|
||||
extra_info: dict[str, str] = {}
|
||||
# extra info like description will be set here
|
||||
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
|
||||
|
||||
if len(tool_bundles) > 100:
|
||||
raise ValueError("the number of apis should be less than 100")
|
||||
|
||||
# create db provider
|
||||
db_provider = ApiToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
name=provider_name,
|
||||
icon=json.dumps(icon),
|
||||
schema=schema,
|
||||
description=extra_info.get("description", ""),
|
||||
schema_type_str=schema_type,
|
||||
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
|
||||
credentials_str="{}",
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
)
|
||||
|
||||
if "auth_type" not in credentials:
|
||||
raise ValueError("auth_type is required")
|
||||
|
||||
# get auth type, none or api key
|
||||
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
|
||||
|
||||
# create provider entity
|
||||
provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
|
||||
# load tools into provider entity
|
||||
provider_controller.load_bundled_tools(tool_bundles)
|
||||
|
||||
# encrypt credentials
|
||||
encrypter, _ = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
)
|
||||
db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials))
|
||||
|
||||
db.session.add(db_provider)
|
||||
db.session.commit()
|
||||
|
||||
# update labels
|
||||
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str):
|
||||
"""
|
||||
get api tool provider remote schema
|
||||
"""
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko)"
|
||||
" Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
|
||||
"Accept": "*/*",
|
||||
}
|
||||
|
||||
try:
|
||||
response = get(url, headers=headers, timeout=10)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Got status code {response.status_code}")
|
||||
schema = response.text
|
||||
|
||||
# try to parse schema, avoid SSRF attack
|
||||
ApiToolManageService.parser_api_schema(schema)
|
||||
except Exception:
|
||||
logger.exception("parse api schema error")
|
||||
raise ValueError("invalid schema, please check the url you provided")
|
||||
|
||||
return {"schema": schema}
|
||||
|
||||
@staticmethod
|
||||
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[ToolApiEntity]:
|
||||
"""
|
||||
list api tool provider tools
|
||||
"""
|
||||
provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
|
||||
controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
|
||||
labels = ToolLabelManager.get_tool_labels(controller)
|
||||
|
||||
return [
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tool_bundle,
|
||||
tenant_id=tenant_id,
|
||||
labels=labels,
|
||||
)
|
||||
for tool_bundle in provider.tools
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def update_api_tool_provider(
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
provider_name: str,
|
||||
original_provider: str,
|
||||
icon: dict,
|
||||
credentials: dict,
|
||||
schema_type: str,
|
||||
schema: str,
|
||||
privacy_policy: str,
|
||||
custom_disclaimer: str,
|
||||
labels: list[str],
|
||||
):
|
||||
"""
|
||||
update api tool provider
|
||||
"""
|
||||
if schema_type not in [member.value for member in ApiProviderSchemaType]:
|
||||
raise ValueError(f"invalid schema type {schema}")
|
||||
|
||||
provider_name = provider_name.strip()
|
||||
|
||||
# check if the provider exists
|
||||
provider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == original_provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f"api provider {provider_name} does not exists")
|
||||
# parse openapi to tool bundle
|
||||
extra_info: dict[str, str] = {}
|
||||
# extra info like description will be set here
|
||||
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
|
||||
|
||||
# update db provider
|
||||
provider.name = provider_name
|
||||
provider.icon = json.dumps(icon)
|
||||
provider.schema = schema
|
||||
provider.description = extra_info.get("description", "")
|
||||
provider.schema_type_str = ApiProviderSchemaType.OPENAPI
|
||||
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
|
||||
provider.privacy_policy = privacy_policy
|
||||
provider.custom_disclaimer = custom_disclaimer
|
||||
|
||||
if "auth_type" not in credentials:
|
||||
raise ValueError("auth_type is required")
|
||||
|
||||
# get auth type, none or api key
|
||||
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
|
||||
|
||||
# create provider entity
|
||||
provider_controller = ApiToolProviderController.from_db(provider, auth_type)
|
||||
# load tools into provider entity
|
||||
provider_controller.load_bundled_tools(tool_bundles)
|
||||
|
||||
# get original credentials if exists
|
||||
encrypter, cache = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
)
|
||||
|
||||
original_credentials = encrypter.decrypt(provider.credentials)
|
||||
masked_credentials = encrypter.mask_plugin_credentials(original_credentials)
|
||||
# check if the credential has changed, save the original credential
|
||||
for name, value in credentials.items():
|
||||
if name in masked_credentials and value == masked_credentials[name]:
|
||||
credentials[name] = original_credentials[name]
|
||||
|
||||
credentials = dict(encrypter.encrypt(credentials))
|
||||
provider.credentials_str = json.dumps(credentials)
|
||||
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
|
||||
# delete cache
|
||||
cache.delete()
|
||||
|
||||
# update labels
|
||||
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str):
|
||||
"""
|
||||
delete tool provider
|
||||
"""
|
||||
provider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
|
||||
db.session.delete(provider)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_api_tool_provider(user_id: str, tenant_id: str, provider: str):
|
||||
"""
|
||||
get api tool provider
|
||||
"""
|
||||
return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id)
|
||||
|
||||
@staticmethod
|
||||
def test_api_tool_preview(
|
||||
tenant_id: str,
|
||||
provider_name: str,
|
||||
tool_name: str,
|
||||
credentials: dict,
|
||||
parameters: dict,
|
||||
schema_type: str,
|
||||
schema: str,
|
||||
):
|
||||
"""
|
||||
test api tool before adding api tool provider
|
||||
"""
|
||||
if schema_type not in [member.value for member in ApiProviderSchemaType]:
|
||||
raise ValueError(f"invalid schema type {schema_type}")
|
||||
|
||||
try:
|
||||
tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
|
||||
except Exception:
|
||||
raise ValueError("invalid schema")
|
||||
|
||||
# get tool bundle
|
||||
tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None)
|
||||
if tool_bundle is None:
|
||||
raise ValueError(f"invalid tool name {tool_name}")
|
||||
|
||||
db_provider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not db_provider:
|
||||
# create a fake db provider
|
||||
db_provider = ApiToolProvider(
|
||||
tenant_id="",
|
||||
user_id="",
|
||||
name="",
|
||||
icon="",
|
||||
schema=schema,
|
||||
description="",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI,
|
||||
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
|
||||
credentials_str=json.dumps(credentials),
|
||||
)
|
||||
|
||||
if "auth_type" not in credentials:
|
||||
raise ValueError("auth_type is required")
|
||||
|
||||
# get auth type, none or api key
|
||||
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
|
||||
|
||||
# create provider entity
|
||||
provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
|
||||
# load tools into provider entity
|
||||
provider_controller.load_bundled_tools(tool_bundles)
|
||||
|
||||
# decrypt credentials
|
||||
if db_provider.id:
|
||||
encrypter, _ = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
)
|
||||
decrypted_credentials = encrypter.decrypt(credentials)
|
||||
# check if the credential has changed, save the original credential
|
||||
masked_credentials = encrypter.mask_plugin_credentials(decrypted_credentials)
|
||||
for name, value in credentials.items():
|
||||
if name in masked_credentials and value == masked_credentials[name]:
|
||||
credentials[name] = decrypted_credentials[name]
|
||||
|
||||
try:
|
||||
provider_controller.validate_credentials_format(credentials)
|
||||
# get tool
|
||||
tool = provider_controller.get_tool(tool_name)
|
||||
tool = tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
credentials=credentials,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
result = tool.validate_credentials(credentials, parameters)
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
return {"result": result or "empty response"}
|
||||
|
||||
@staticmethod
|
||||
def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||
"""
|
||||
list api tools
|
||||
"""
|
||||
# get all api providers
|
||||
db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all()
|
||||
|
||||
result: list[ToolProviderApiEntity] = []
|
||||
|
||||
for provider in db_providers:
|
||||
# convert provider controller to user provider
|
||||
provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
|
||||
labels = ToolLabelManager.get_tool_labels(provider_controller)
|
||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||
provider_controller, db_provider=provider, decrypt_credentials=True
|
||||
)
|
||||
user_provider.labels = labels
|
||||
|
||||
# add icon
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_provider)
|
||||
|
||||
tools = provider_controller.get_tools(tenant_id=tenant_id)
|
||||
|
||||
for tool in tools or []:
|
||||
user_provider.tools.append(
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tenant_id=tenant_id, tool=tool, labels=labels
|
||||
)
|
||||
)
|
||||
|
||||
result.append(user_provider)
|
||||
|
||||
return result
|
||||
727
dify/api/services/tools/builtin_tools_manage_service.py
Normal file
727
dify/api/services/tools/builtin_tools_manage_service.py
Normal file
@@ -0,0 +1,727 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
from core.helper.name_generator import generate_incremental_name
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||
from core.tools.entities.api_entities import (
|
||||
ToolApiEntity,
|
||||
ToolProviderApiEntity,
|
||||
ToolProviderCredentialApiEntity,
|
||||
ToolProviderCredentialInfoApiEntity,
|
||||
)
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider_ids import ToolProviderID
|
||||
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
|
||||
from services.plugin.plugin_service import PluginService
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BuiltinToolManageService:
|
||||
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
|
||||
|
||||
@staticmethod
|
||||
def delete_custom_oauth_client_params(tenant_id: str, provider: str):
|
||||
"""
|
||||
delete custom oauth client params
|
||||
"""
|
||||
tool_provider = ToolProviderID(provider)
|
||||
with Session(db.engine) as session:
|
||||
session.query(ToolOAuthTenantClient).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=tool_provider.provider_name,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
).delete()
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_oauth_client_schema(tenant_id: str, provider_name: str):
|
||||
"""
|
||||
get builtin tool provider oauth client schema
|
||||
"""
|
||||
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified(
|
||||
tenant_id, provider.plugin_unique_identifier
|
||||
)
|
||||
|
||||
is_oauth_custom_client_enabled = BuiltinToolManageService.is_oauth_custom_client_enabled(
|
||||
tenant_id, provider_name
|
||||
)
|
||||
is_system_oauth_params_exists = verified and BuiltinToolManageService.is_oauth_system_client_exists(
|
||||
provider_name
|
||||
)
|
||||
result = {
|
||||
"schema": provider.get_oauth_client_schema(),
|
||||
"is_oauth_custom_client_enabled": is_oauth_custom_client_enabled,
|
||||
"is_system_oauth_params_exists": is_system_oauth_params_exists,
|
||||
"client_params": BuiltinToolManageService.get_custom_oauth_client_params(tenant_id, provider_name),
|
||||
"redirect_uri": f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_name}/tool/callback",
|
||||
}
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
|
||||
"""
|
||||
list builtin tool provider tools
|
||||
|
||||
:param tenant_id: the id of the tenant
|
||||
:param provider: the name of the provider
|
||||
|
||||
:return: the list of tools
|
||||
"""
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
tools = provider_controller.get_tools()
|
||||
|
||||
result: list[ToolApiEntity] = []
|
||||
for tool in tools or []:
|
||||
result.append(
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tool=tool,
|
||||
tenant_id=tenant_id,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_info(tenant_id: str, provider: str):
|
||||
"""
|
||||
get builtin tool provider info
|
||||
"""
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
# check if user has added the provider
|
||||
builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id)
|
||||
if builtin_provider is None:
|
||||
raise ValueError(f"you have not added provider {provider}")
|
||||
|
||||
entity = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=builtin_provider,
|
||||
decrypt_credentials=True,
|
||||
)
|
||||
|
||||
entity.original_credentials = {}
|
||||
return entity
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_provider_credentials_schema(provider_name: str, credential_type: CredentialType, tenant_id: str):
|
||||
"""
|
||||
list builtin provider credentials schema
|
||||
|
||||
:param credential_type: credential type
|
||||
:param provider_name: the name of the provider
|
||||
:param tenant_id: the id of the tenant
|
||||
:return: the list of tool providers
|
||||
"""
|
||||
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
return provider.get_credentials_schema_by_type(credential_type)
|
||||
|
||||
@staticmethod
|
||||
def update_builtin_tool_provider(
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
credential_id: str,
|
||||
credentials: dict | None = None,
|
||||
name: str | None = None,
|
||||
):
|
||||
"""
|
||||
update builtin tool provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
# get if the provider exists
|
||||
db_provider = (
|
||||
session.query(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.id == credential_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if db_provider is None:
|
||||
raise ValueError(f"you have not added provider {provider}")
|
||||
|
||||
try:
|
||||
if CredentialType.of(db_provider.credential_type).is_editable() and credentials:
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
if not provider_controller.need_credentials:
|
||||
raise ValueError(f"provider {provider} does not need credentials")
|
||||
|
||||
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
|
||||
tenant_id, db_provider, provider, provider_controller
|
||||
)
|
||||
|
||||
original_credentials = encrypter.decrypt(db_provider.credentials)
|
||||
new_credentials: dict = {
|
||||
key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
|
||||
for key, value in credentials.items()
|
||||
}
|
||||
|
||||
if CredentialType.of(db_provider.credential_type).is_validate_allowed():
|
||||
provider_controller.validate_credentials(user_id, new_credentials)
|
||||
|
||||
# encrypt credentials
|
||||
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))
|
||||
|
||||
cache.delete()
|
||||
|
||||
# update name if provided
|
||||
if name and name != db_provider.name:
|
||||
# check if the name is already used
|
||||
if session.scalar(
|
||||
select(
|
||||
exists().where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
BuiltinToolProvider.name == name,
|
||||
)
|
||||
)
|
||||
):
|
||||
raise ValueError(f"the credential name '{name}' is already used")
|
||||
|
||||
db_provider.name = name
|
||||
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def add_builtin_tool_provider(
|
||||
user_id: str,
|
||||
api_type: CredentialType,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
credentials: dict,
|
||||
expires_at: int = -1,
|
||||
name: str | None = None,
|
||||
):
|
||||
"""
|
||||
add builtin tool provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
if not provider_controller.need_credentials:
|
||||
raise ValueError(f"provider {provider} does not need credentials")
|
||||
|
||||
provider_count = (
|
||||
session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
|
||||
)
|
||||
|
||||
# check if the provider count is reached the limit
|
||||
if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
|
||||
raise ValueError(f"you have reached the maximum number of providers for {provider}")
|
||||
|
||||
# validate credentials if allowed
|
||||
if CredentialType.of(api_type).is_validate_allowed():
|
||||
provider_controller.validate_credentials(user_id, credentials)
|
||||
|
||||
# generate name if not provided
|
||||
if name is None or name == "":
|
||||
name = BuiltinToolManageService.generate_builtin_tool_provider_name(
|
||||
session=session, tenant_id=tenant_id, provider=provider, credential_type=api_type
|
||||
)
|
||||
else:
|
||||
# check if the name is already used
|
||||
if session.scalar(
|
||||
select(
|
||||
exists().where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
BuiltinToolProvider.name == name,
|
||||
)
|
||||
)
|
||||
):
|
||||
raise ValueError(f"the credential name '{name}' is already used")
|
||||
|
||||
# create encrypter
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
for x in provider_controller.get_credentials_schema_by_type(api_type)
|
||||
],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
db_provider = BuiltinToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
|
||||
credential_type=api_type.value,
|
||||
name=name,
|
||||
expires_at=expires_at if expires_at is not None else -1,
|
||||
)
|
||||
|
||||
session.add(db_provider)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def create_tool_encrypter(
|
||||
tenant_id: str,
|
||||
db_provider: BuiltinToolProvider,
|
||||
provider: str,
|
||||
provider_controller: BuiltinToolProviderController,
|
||||
):
|
||||
encrypter, cache = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
for x in provider_controller.get_credentials_schema_by_type(db_provider.credential_type)
|
||||
],
|
||||
cache=ToolProviderCredentialsCache(tenant_id=tenant_id, provider=provider, credential_id=db_provider.id),
|
||||
)
|
||||
return encrypter, cache
|
||||
|
||||
@staticmethod
|
||||
def generate_builtin_tool_provider_name(
|
||||
session: Session, tenant_id: str, provider: str, credential_type: CredentialType
|
||||
) -> str:
|
||||
db_providers = (
|
||||
session.query(BuiltinToolProvider)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
credential_type=credential_type.value,
|
||||
)
|
||||
.order_by(BuiltinToolProvider.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
return generate_incremental_name(
|
||||
[provider.name for provider in db_providers],
|
||||
f"{credential_type.get_name()}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_credentials(
|
||||
tenant_id: str, provider_name: str
|
||||
) -> list[ToolProviderCredentialApiEntity]:
|
||||
"""
|
||||
get builtin tool provider credentials
|
||||
"""
|
||||
with db.session.no_autoflush:
|
||||
providers = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter_by(tenant_id=tenant_id, provider=provider_name)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
if len(providers) == 0:
|
||||
return []
|
||||
|
||||
default_provider = providers[0]
|
||||
default_provider.is_default = True
|
||||
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
|
||||
|
||||
credentials: list[ToolProviderCredentialApiEntity] = []
|
||||
for provider in providers:
|
||||
encrypter, _ = BuiltinToolManageService.create_tool_encrypter(
|
||||
tenant_id, provider, provider.provider, provider_controller
|
||||
)
|
||||
decrypt_credential = encrypter.mask_plugin_credentials(encrypter.decrypt(provider.credentials))
|
||||
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
|
||||
provider=provider,
|
||||
credentials=dict(decrypt_credential),
|
||||
)
|
||||
credentials.append(credential_entity)
|
||||
return credentials
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_credential_info(tenant_id: str, provider: str) -> ToolProviderCredentialInfoApiEntity:
|
||||
"""
|
||||
get builtin tool provider credential info
|
||||
"""
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
supported_credential_types = provider_controller.get_supported_credential_types()
|
||||
credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider)
|
||||
credential_info = ToolProviderCredentialInfoApiEntity(
|
||||
supported_credential_types=supported_credential_types,
|
||||
is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider),
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
return credential_info
|
||||
|
||||
@staticmethod
|
||||
def delete_builtin_tool_provider(tenant_id: str, provider: str, credential_id: str):
|
||||
"""
|
||||
delete tool provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
db_provider = (
|
||||
session.query(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.id == credential_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if db_provider is None:
|
||||
raise ValueError(f"you have not added provider {provider}")
|
||||
|
||||
session.delete(db_provider)
|
||||
session.commit()
|
||||
|
||||
# delete cache
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
_, cache = BuiltinToolManageService.create_tool_encrypter(
|
||||
tenant_id, db_provider, provider, provider_controller
|
||||
)
|
||||
cache.delete()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str):
|
||||
"""
|
||||
set default provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
# get provider
|
||||
target_provider = session.query(BuiltinToolProvider).filter_by(id=id).first()
|
||||
if target_provider is None:
|
||||
raise ValueError("provider not found")
|
||||
|
||||
# clear default provider
|
||||
session.query(BuiltinToolProvider).filter_by(
|
||||
tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
|
||||
).update({"is_default": False})
|
||||
|
||||
# set new default provider
|
||||
target_provider.is_default = True
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def is_oauth_system_client_exists(provider_name: str) -> bool:
|
||||
"""
|
||||
check if oauth system client exists
|
||||
"""
|
||||
tool_provider = ToolProviderID(provider_name)
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
system_client: ToolOAuthSystemClient | None = (
|
||||
session.query(ToolOAuthSystemClient)
|
||||
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
|
||||
.first()
|
||||
)
|
||||
return system_client is not None
|
||||
|
||||
@staticmethod
|
||||
def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool:
|
||||
"""
|
||||
check if oauth custom client is enabled
|
||||
"""
|
||||
tool_provider = ToolProviderID(provider)
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
user_client: ToolOAuthTenantClient | None = (
|
||||
session.query(ToolOAuthTenantClient)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=tool_provider.provider_name,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
enabled=True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return user_client is not None and user_client.enabled
|
||||
|
||||
@staticmethod
|
||||
def get_oauth_client(tenant_id: str, provider: str) -> Mapping[str, Any] | None:
|
||||
"""
|
||||
get builtin tool provider
|
||||
"""
|
||||
tool_provider = ToolProviderID(provider)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
user_client: ToolOAuthTenantClient | None = (
|
||||
session.query(ToolOAuthTenantClient)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=tool_provider.provider_name,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
enabled=True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
oauth_params: Mapping[str, Any] | None = None
|
||||
if user_client:
|
||||
oauth_params = encrypter.decrypt(user_client.oauth_params)
|
||||
return oauth_params
|
||||
|
||||
# only verified provider can use official oauth client
|
||||
is_verified = not isinstance(
|
||||
provider_controller, PluginToolProviderController
|
||||
) or PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
|
||||
if not is_verified:
|
||||
return oauth_params
|
||||
|
||||
system_client: ToolOAuthSystemClient | None = (
|
||||
session.query(ToolOAuthSystemClient)
|
||||
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
|
||||
.first()
|
||||
)
|
||||
if system_client:
|
||||
try:
|
||||
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error decrypting system oauth params: {e}")
|
||||
|
||||
return oauth_params
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_icon(provider: str):
|
||||
"""
|
||||
get tool provider icon and it's mimetype
|
||||
"""
|
||||
icon_path, mime_type = ToolManager.get_hardcoded_provider_icon(provider)
|
||||
icon_bytes = Path(icon_path).read_bytes()
|
||||
|
||||
return icon_bytes, mime_type
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||
"""
|
||||
list builtin tools
|
||||
"""
|
||||
# get all builtin providers
|
||||
provider_controllers = ToolManager.list_builtin_providers(tenant_id)
|
||||
|
||||
# get all user added providers
|
||||
db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
|
||||
|
||||
# rewrite db_providers
|
||||
for db_provider in db_providers:
|
||||
db_provider.provider = str(ToolProviderID(db_provider.provider))
|
||||
|
||||
# find provider
|
||||
def find_provider(provider):
|
||||
return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
|
||||
|
||||
result: list[ToolProviderApiEntity] = []
|
||||
|
||||
for provider_controller in provider_controllers:
|
||||
try:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
|
||||
data=provider_controller,
|
||||
name_func=lambda x: x.entity.identity.name,
|
||||
):
|
||||
continue
|
||||
|
||||
# convert provider controller to user provider
|
||||
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=find_provider(provider_controller.entity.identity.name),
|
||||
decrypt_credentials=True,
|
||||
)
|
||||
|
||||
# add icon
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
|
||||
|
||||
tools = provider_controller.get_tools()
|
||||
for tool in tools or []:
|
||||
user_builtin_provider.tools.append(
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tenant_id=tenant_id,
|
||||
tool=tool,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
||||
)
|
||||
)
|
||||
|
||||
result.append(user_builtin_provider)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
return BuiltinToolProviderSort.sort(result)
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
|
||||
"""
|
||||
This method is used to fetch the builtin provider from the database
|
||||
1.if the default provider exists, return the default provider
|
||||
2.if the default provider does not exist, return the oldest provider
|
||||
"""
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
try:
|
||||
full_provider_name = provider_name
|
||||
provider_id_entity = ToolProviderID(provider_name)
|
||||
provider_name = provider_id_entity.provider_name
|
||||
|
||||
if provider_id_entity.organization != "langgenius":
|
||||
provider = (
|
||||
session.query(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == full_provider_name,
|
||||
)
|
||||
.order_by(
|
||||
BuiltinToolProvider.is_default.desc(), # default=True first
|
||||
BuiltinToolProvider.created_at.asc(), # oldest first
|
||||
)
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
provider = (
|
||||
session.query(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == provider_name)
|
||||
| (BuiltinToolProvider.provider == full_provider_name),
|
||||
)
|
||||
.order_by(
|
||||
BuiltinToolProvider.is_default.desc(), # default=True first
|
||||
BuiltinToolProvider.created_at.asc(), # oldest first
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
return None
|
||||
|
||||
provider.provider = ToolProviderID(provider.provider).to_string()
|
||||
return provider
|
||||
except Exception:
|
||||
# it's an old provider without organization
|
||||
return (
|
||||
session.query(BuiltinToolProvider)
|
||||
.where(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
|
||||
.order_by(
|
||||
BuiltinToolProvider.is_default.desc(), # default=True first
|
||||
BuiltinToolProvider.created_at.asc(), # oldest first
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def save_custom_oauth_client_params(
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
client_params: dict | None = None,
|
||||
enable_oauth_custom_client: bool | None = None,
|
||||
):
|
||||
"""
|
||||
setup oauth custom client
|
||||
"""
|
||||
if client_params is None and enable_oauth_custom_client is None:
|
||||
return {"result": "success"}
|
||||
|
||||
tool_provider = ToolProviderID(provider)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
if not provider_controller:
|
||||
raise ToolProviderNotFoundError(f"Provider {provider} not found")
|
||||
|
||||
if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
|
||||
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
custom_client_params = (
|
||||
session.query(ToolOAuthTenantClient)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
provider=tool_provider.provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# if the record does not exist, create a basic record
|
||||
if custom_client_params is None:
|
||||
custom_client_params = ToolOAuthTenantClient(
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
provider=tool_provider.provider_name,
|
||||
)
|
||||
session.add(custom_client_params)
|
||||
|
||||
if client_params is not None:
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
original_params = encrypter.decrypt(custom_client_params.oauth_params)
|
||||
new_params = {
|
||||
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
|
||||
for key, value in client_params.items()
|
||||
}
|
||||
custom_client_params.encrypted_oauth_params = json.dumps(encrypter.encrypt(new_params))
|
||||
|
||||
if enable_oauth_custom_client is not None:
|
||||
custom_client_params.enabled = enable_oauth_custom_client
|
||||
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_custom_oauth_client_params(tenant_id: str, provider: str):
|
||||
"""
|
||||
get custom oauth client params
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
tool_provider = ToolProviderID(provider)
|
||||
custom_oauth_client_params: ToolOAuthTenantClient | None = (
|
||||
session.query(ToolOAuthTenantClient)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
provider=tool_provider.provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if custom_oauth_client_params is None:
|
||||
return {}
|
||||
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
if not provider_controller:
|
||||
raise ToolProviderNotFoundError(f"Provider {provider} not found")
|
||||
|
||||
if not isinstance(provider_controller, BuiltinToolProviderController):
|
||||
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
|
||||
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
return encrypter.mask_plugin_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))
|
||||
734
dify/api/services/tools/mcp_tools_manage_service.py
Normal file
734
dify/api/services/tools/mcp_tools_manage_service.py
Normal file
@@ -0,0 +1,734 @@
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
|
||||
from core.helper import encrypter
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPAuthError, MCPError
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||
from core.tools.utils.encryption import ProviderConfigEncrypter
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
|
||||
CLIENT_NAME = "Dify"
|
||||
EMPTY_TOOLS_JSON = "[]"
|
||||
EMPTY_CREDENTIALS_JSON = "{}"
|
||||
|
||||
|
||||
class OAuthDataType(StrEnum):
|
||||
"""Types of OAuth data that can be saved."""
|
||||
|
||||
TOKENS = "tokens"
|
||||
CLIENT_INFO = "client_info"
|
||||
CODE_VERIFIER = "code_verifier"
|
||||
MIXED = "mixed"
|
||||
|
||||
|
||||
class ReconnectResult(BaseModel):
|
||||
"""Result of reconnecting to an MCP provider"""
|
||||
|
||||
authed: bool = Field(description="Whether the provider is authenticated")
|
||||
tools: str = Field(description="JSON string of tool list")
|
||||
encrypted_credentials: str = Field(description="JSON string of encrypted credentials")
|
||||
|
||||
|
||||
class ServerUrlValidationResult(BaseModel):
|
||||
"""Result of server URL validation check"""
|
||||
|
||||
needs_validation: bool
|
||||
validation_passed: bool = False
|
||||
reconnect_result: ReconnectResult | None = None
|
||||
encrypted_server_url: str | None = None
|
||||
server_url_hash: str | None = None
|
||||
|
||||
@property
|
||||
def should_update_server_url(self) -> bool:
|
||||
"""Check if server URL should be updated based on validation result"""
|
||||
return self.needs_validation and self.validation_passed and self.reconnect_result is not None
|
||||
|
||||
|
||||
class MCPToolManageService:
|
||||
"""Service class for managing MCP tools and providers."""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
self._session = session
|
||||
|
||||
# ========== Provider CRUD Operations ==========
|
||||
|
||||
def get_provider(
|
||||
self, *, provider_id: str | None = None, server_identifier: str | None = None, tenant_id: str
|
||||
) -> MCPToolProvider:
|
||||
"""
|
||||
Get MCP provider by ID or server identifier.
|
||||
|
||||
Args:
|
||||
provider_id: Provider ID (UUID)
|
||||
server_identifier: Server identifier
|
||||
tenant_id: Tenant ID
|
||||
|
||||
Returns:
|
||||
MCPToolProvider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If provider not found
|
||||
"""
|
||||
if server_identifier:
|
||||
stmt = select(MCPToolProvider).where(
|
||||
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier
|
||||
)
|
||||
else:
|
||||
stmt = select(MCPToolProvider).where(
|
||||
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id
|
||||
)
|
||||
|
||||
provider = self._session.scalar(stmt)
|
||||
if not provider:
|
||||
raise ValueError("MCP tool not found")
|
||||
return provider
|
||||
|
||||
def get_provider_entity(self, provider_id: str, tenant_id: str, by_server_id: bool = False) -> MCPProviderEntity:
|
||||
"""Get provider entity by ID or server identifier."""
|
||||
if by_server_id:
|
||||
db_provider = self.get_provider(server_identifier=provider_id, tenant_id=tenant_id)
|
||||
else:
|
||||
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
return db_provider.to_entity()
|
||||
|
||||
def create_provider(
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
name: str,
|
||||
server_url: str,
|
||||
user_id: str,
|
||||
icon: str,
|
||||
icon_type: str,
|
||||
icon_background: str,
|
||||
server_identifier: str,
|
||||
configuration: MCPConfiguration,
|
||||
authentication: MCPAuthentication | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> ToolProviderApiEntity:
|
||||
"""Create a new MCP provider."""
|
||||
# Validate URL format
|
||||
if not self._is_valid_url(server_url):
|
||||
raise ValueError("Server URL is not valid.")
|
||||
|
||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||
|
||||
# Check for existing provider
|
||||
self._check_provider_exists(tenant_id, name, server_url_hash, server_identifier)
|
||||
|
||||
# Encrypt sensitive data
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||
encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None
|
||||
encrypted_credentials = None
|
||||
if authentication is not None and authentication.client_id:
|
||||
encrypted_credentials = self._build_and_encrypt_credentials(
|
||||
authentication.client_id, authentication.client_secret, tenant_id
|
||||
)
|
||||
|
||||
# Create provider
|
||||
mcp_tool = MCPToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
server_url=encrypted_server_url,
|
||||
server_url_hash=server_url_hash,
|
||||
user_id=user_id,
|
||||
authed=False,
|
||||
tools=EMPTY_TOOLS_JSON,
|
||||
icon=self._prepare_icon(icon, icon_type, icon_background),
|
||||
server_identifier=server_identifier,
|
||||
timeout=configuration.timeout,
|
||||
sse_read_timeout=configuration.sse_read_timeout,
|
||||
encrypted_headers=encrypted_headers,
|
||||
encrypted_credentials=encrypted_credentials,
|
||||
)
|
||||
|
||||
self._session.add(mcp_tool)
|
||||
self._session.flush()
|
||||
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
|
||||
return mcp_providers
|
||||
|
||||
def update_provider(
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
provider_id: str,
|
||||
name: str,
|
||||
server_url: str,
|
||||
icon: str,
|
||||
icon_type: str,
|
||||
icon_background: str,
|
||||
server_identifier: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
configuration: MCPConfiguration,
|
||||
authentication: MCPAuthentication | None = None,
|
||||
validation_result: ServerUrlValidationResult | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Update an MCP provider.
|
||||
|
||||
Args:
|
||||
validation_result: Pre-validation result from validate_server_url_change.
|
||||
If provided and contains reconnect_result, it will be used
|
||||
instead of performing network operations.
|
||||
"""
|
||||
mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
|
||||
# Check for duplicate name (excluding current provider)
|
||||
if name != mcp_provider.name:
|
||||
stmt = select(MCPToolProvider).where(
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
MCPToolProvider.name == name,
|
||||
MCPToolProvider.id != provider_id,
|
||||
)
|
||||
existing_provider = self._session.scalar(stmt)
|
||||
if existing_provider:
|
||||
raise ValueError(f"MCP tool {name} already exists")
|
||||
|
||||
# Get URL update data from validation result
|
||||
encrypted_server_url = None
|
||||
server_url_hash = None
|
||||
reconnect_result = None
|
||||
|
||||
if validation_result and validation_result.encrypted_server_url:
|
||||
# Use all data from validation result
|
||||
encrypted_server_url = validation_result.encrypted_server_url
|
||||
server_url_hash = validation_result.server_url_hash
|
||||
reconnect_result = validation_result.reconnect_result
|
||||
|
||||
try:
|
||||
# Update basic fields
|
||||
mcp_provider.updated_at = datetime.now()
|
||||
mcp_provider.name = name
|
||||
mcp_provider.icon = self._prepare_icon(icon, icon_type, icon_background)
|
||||
mcp_provider.server_identifier = server_identifier
|
||||
|
||||
# Update server URL if changed
|
||||
if encrypted_server_url and server_url_hash:
|
||||
mcp_provider.server_url = encrypted_server_url
|
||||
mcp_provider.server_url_hash = server_url_hash
|
||||
|
||||
if reconnect_result:
|
||||
mcp_provider.authed = reconnect_result.authed
|
||||
mcp_provider.tools = reconnect_result.tools
|
||||
mcp_provider.encrypted_credentials = reconnect_result.encrypted_credentials
|
||||
|
||||
# Update optional configuration fields
|
||||
self._update_optional_fields(mcp_provider, configuration)
|
||||
|
||||
# Update headers if provided
|
||||
if headers is not None:
|
||||
mcp_provider.encrypted_headers = self._process_headers(headers, mcp_provider, tenant_id)
|
||||
|
||||
# Update credentials if provided
|
||||
if authentication and authentication.client_id:
|
||||
mcp_provider.encrypted_credentials = self._process_credentials(authentication, mcp_provider, tenant_id)
|
||||
|
||||
# Flush changes to database
|
||||
self._session.flush()
|
||||
except IntegrityError as e:
|
||||
self._handle_integrity_error(e, name, server_url, server_identifier)
|
||||
|
||||
def delete_provider(self, *, tenant_id: str, provider_id: str) -> None:
|
||||
"""Delete an MCP provider."""
|
||||
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
self._session.delete(mcp_tool)
|
||||
|
||||
def list_providers(
|
||||
self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True
|
||||
) -> list[ToolProviderApiEntity]:
|
||||
"""List all MCP providers for a tenant.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
for_list: If True, return provider ID; if False, return server identifier
|
||||
include_sensitive: If False, skip expensive decryption operations (default: True for backward compatibility)
|
||||
"""
|
||||
from models.account import Account
|
||||
|
||||
stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name)
|
||||
mcp_providers = self._session.scalars(stmt).all()
|
||||
|
||||
if not mcp_providers:
|
||||
return []
|
||||
|
||||
# Batch query all users to avoid N+1 problem
|
||||
user_ids = {provider.user_id for provider in mcp_providers}
|
||||
users = self._session.query(Account).where(Account.id.in_(user_ids)).all()
|
||||
user_name_map = {user.id: user.name for user in users}
|
||||
|
||||
return [
|
||||
ToolTransformService.mcp_provider_to_user_provider(
|
||||
provider,
|
||||
for_list=for_list,
|
||||
user_name=user_name_map.get(provider.user_id),
|
||||
include_sensitive=include_sensitive,
|
||||
)
|
||||
for provider in mcp_providers
|
||||
]
|
||||
|
||||
# ========== Tool Operations ==========
|
||||
|
||||
def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
|
||||
"""List tools from remote MCP server."""
|
||||
# Load provider and convert to entity
|
||||
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
provider_entity = db_provider.to_entity()
|
||||
|
||||
# Verify authentication
|
||||
if not provider_entity.authed:
|
||||
raise ValueError("Please auth the tool first")
|
||||
|
||||
# Prepare headers with auth token
|
||||
headers = self._prepare_auth_headers(provider_entity)
|
||||
|
||||
# Retrieve tools from remote server
|
||||
server_url = provider_entity.decrypt_server_url()
|
||||
try:
|
||||
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
|
||||
except MCPError as e:
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}")
|
||||
|
||||
# Update database with retrieved tools
|
||||
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
||||
db_provider.authed = True
|
||||
db_provider.updated_at = datetime.now()
|
||||
self._session.flush()
|
||||
|
||||
# Build API response
|
||||
return self._build_tool_provider_response(db_provider, provider_entity, tools)
|
||||
|
||||
# ========== OAuth and Credentials Operations ==========
|
||||
|
||||
def update_provider_credentials(
|
||||
self, *, provider_id: str, tenant_id: str, credentials: dict[str, Any], authed: bool | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Update provider credentials with encryption.
|
||||
|
||||
Args:
|
||||
provider_id: Provider ID
|
||||
tenant_id: Tenant ID
|
||||
credentials: Credentials to save
|
||||
authed: Whether provider is authenticated (None means keep current state)
|
||||
"""
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
|
||||
# Get provider from current session
|
||||
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
|
||||
# Encrypt new credentials
|
||||
provider_controller = MCPToolProviderController.from_db(provider)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=provider.tenant_id,
|
||||
config=list(provider_controller.get_credentials_schema()),
|
||||
provider_config_cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
encrypted_credentials = tool_configuration.encrypt(credentials)
|
||||
|
||||
# Update provider
|
||||
provider.updated_at = datetime.now()
|
||||
provider.encrypted_credentials = json.dumps({**provider.credentials, **encrypted_credentials})
|
||||
|
||||
if authed is not None:
|
||||
provider.authed = authed
|
||||
if not authed:
|
||||
provider.tools = EMPTY_TOOLS_JSON
|
||||
|
||||
# Flush changes to database
|
||||
self._session.flush()
|
||||
|
||||
def save_oauth_data(
|
||||
self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: OAuthDataType = OAuthDataType.MIXED
|
||||
) -> None:
|
||||
"""
|
||||
Save OAuth-related data (tokens, client info, code verifier).
|
||||
|
||||
Args:
|
||||
provider_id: Provider ID
|
||||
tenant_id: Tenant ID
|
||||
data: Data to save (tokens, client info, or code verifier)
|
||||
data_type: Type of OAuth data to save
|
||||
"""
|
||||
# Determine if this makes the provider authenticated
|
||||
authed = (
|
||||
data_type == OAuthDataType.TOKENS or (data_type == OAuthDataType.MIXED and "access_token" in data) or None
|
||||
)
|
||||
|
||||
# update_provider_credentials will validate provider existence
|
||||
self.update_provider_credentials(provider_id=provider_id, tenant_id=tenant_id, credentials=data, authed=authed)
|
||||
|
||||
def clear_provider_credentials(self, *, provider_id: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Clear all credentials for a provider.
|
||||
|
||||
Args:
|
||||
provider_id: Provider ID
|
||||
tenant_id: Tenant ID
|
||||
"""
|
||||
# Get provider from current session
|
||||
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
|
||||
provider.tools = EMPTY_TOOLS_JSON
|
||||
provider.encrypted_credentials = EMPTY_CREDENTIALS_JSON
|
||||
provider.updated_at = datetime.now()
|
||||
provider.authed = False
|
||||
|
||||
# ========== Private Helper Methods ==========
|
||||
|
||||
def _check_provider_exists(self, tenant_id: str, name: str, server_url_hash: str, server_identifier: str) -> None:
|
||||
"""Check if provider with same attributes already exists."""
|
||||
stmt = select(MCPToolProvider).where(
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
or_(
|
||||
MCPToolProvider.name == name,
|
||||
MCPToolProvider.server_url_hash == server_url_hash,
|
||||
MCPToolProvider.server_identifier == server_identifier,
|
||||
),
|
||||
)
|
||||
existing_provider = self._session.scalar(stmt)
|
||||
|
||||
if existing_provider:
|
||||
if existing_provider.name == name:
|
||||
raise ValueError(f"MCP tool {name} already exists")
|
||||
if existing_provider.server_url_hash == server_url_hash:
|
||||
raise ValueError("MCP tool with this server URL already exists")
|
||||
if existing_provider.server_identifier == server_identifier:
|
||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||
|
||||
def _prepare_icon(self, icon: str, icon_type: str, icon_background: str) -> str:
|
||||
"""Prepare icon data for storage."""
|
||||
if icon_type == "emoji":
|
||||
return json.dumps({"content": icon, "background": icon_background})
|
||||
return icon
|
||||
|
||||
def _encrypt_dict_fields(self, data: dict[str, Any], secret_fields: list[str], tenant_id: str) -> Mapping[str, str]:
|
||||
"""Encrypt specified fields in a dictionary.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing data to encrypt
|
||||
secret_fields: List of field names to encrypt
|
||||
tenant_id: Tenant ID for encryption
|
||||
|
||||
Returns:
|
||||
JSON string of encrypted data
|
||||
"""
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
|
||||
# Create config for secret fields
|
||||
config = [
|
||||
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=field) for field in secret_fields
|
||||
]
|
||||
|
||||
encrypter_instance, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=config,
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
encrypted_data = encrypter_instance.encrypt(data)
|
||||
return encrypted_data
|
||||
|
||||
def _prepare_encrypted_dict(self, headers: dict[str, str], tenant_id: str) -> str:
|
||||
"""Encrypt headers and prepare for storage."""
|
||||
# All headers are treated as secret
|
||||
return json.dumps(self._encrypt_dict_fields(headers, list(headers.keys()), tenant_id))
|
||||
|
||||
def _prepare_auth_headers(self, provider_entity: MCPProviderEntity) -> dict[str, str]:
|
||||
"""Prepare headers with OAuth token if available."""
|
||||
headers = provider_entity.decrypt_headers()
|
||||
tokens = provider_entity.retrieve_tokens()
|
||||
if tokens:
|
||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||
return headers
|
||||
|
||||
def _retrieve_remote_mcp_tools(
|
||||
self,
|
||||
server_url: str,
|
||||
headers: dict[str, str],
|
||||
provider_entity: MCPProviderEntity,
|
||||
):
|
||||
"""Retrieve tools from remote MCP server."""
|
||||
with MCPClientWithAuthRetry(
|
||||
server_url=server_url,
|
||||
headers=headers,
|
||||
timeout=provider_entity.timeout,
|
||||
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
) as mcp_client:
|
||||
return mcp_client.list_tools()
|
||||
|
||||
def execute_auth_actions(self, auth_result: Any) -> dict[str, str]:
|
||||
"""
|
||||
Execute the actions returned by the auth function.
|
||||
|
||||
This method processes the AuthResult and performs the necessary database operations.
|
||||
|
||||
Args:
|
||||
auth_result: The result from the auth function
|
||||
|
||||
Returns:
|
||||
The response from the auth result
|
||||
"""
|
||||
from core.mcp.entities import AuthAction, AuthActionType
|
||||
|
||||
action: AuthAction
|
||||
for action in auth_result.actions:
|
||||
if action.provider_id is None or action.tenant_id is None:
|
||||
continue
|
||||
|
||||
if action.action_type == AuthActionType.SAVE_CLIENT_INFO:
|
||||
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CLIENT_INFO)
|
||||
elif action.action_type == AuthActionType.SAVE_TOKENS:
|
||||
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.TOKENS)
|
||||
elif action.action_type == AuthActionType.SAVE_CODE_VERIFIER:
|
||||
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CODE_VERIFIER)
|
||||
|
||||
return auth_result.response
|
||||
|
||||
def auth_with_actions(
|
||||
self,
|
||||
provider_entity: MCPProviderEntity,
|
||||
authorization_code: str | None = None,
|
||||
resource_metadata_url: str | None = None,
|
||||
scope_hint: str | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Perform authentication and execute all resulting actions.
|
||||
|
||||
This method is used by MCPClientWithAuthRetry for automatic re-authentication.
|
||||
|
||||
Args:
|
||||
provider_entity: The MCP provider entity
|
||||
authorization_code: Optional authorization code
|
||||
resource_metadata_url: Optional Protected Resource Metadata URL from WWW-Authenticate
|
||||
scope_hint: Optional scope hint from WWW-Authenticate header
|
||||
|
||||
Returns:
|
||||
Response dictionary from auth result
|
||||
"""
|
||||
auth_result = auth(
|
||||
provider_entity,
|
||||
authorization_code,
|
||||
resource_metadata_url=resource_metadata_url,
|
||||
scope_hint=scope_hint,
|
||||
)
|
||||
return self.execute_auth_actions(auth_result)
|
||||
|
||||
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult:
|
||||
"""Attempt to reconnect to MCP provider with new server URL."""
|
||||
provider_entity = provider.to_entity()
|
||||
headers = provider_entity.headers
|
||||
|
||||
try:
|
||||
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
|
||||
return ReconnectResult(
|
||||
authed=True,
|
||||
tools=json.dumps([tool.model_dump() for tool in tools]),
|
||||
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
|
||||
)
|
||||
except MCPAuthError:
|
||||
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
|
||||
except MCPError as e:
|
||||
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
|
||||
|
||||
def validate_server_url_change(
|
||||
self, *, tenant_id: str, provider_id: str, new_server_url: str
|
||||
) -> ServerUrlValidationResult:
|
||||
"""
|
||||
Validate server URL change by attempting to connect to the new server.
|
||||
This method should be called BEFORE update_provider to perform network operations
|
||||
outside of the database transaction.
|
||||
|
||||
Returns:
|
||||
ServerUrlValidationResult: Validation result with connection status and tools if successful
|
||||
"""
|
||||
# Handle hidden/unchanged URL
|
||||
if UNCHANGED_SERVER_URL_PLACEHOLDER in new_server_url:
|
||||
return ServerUrlValidationResult(needs_validation=False)
|
||||
|
||||
# Validate URL format
|
||||
if not self._is_valid_url(new_server_url):
|
||||
raise ValueError("Server URL is not valid.")
|
||||
|
||||
# Always encrypt and hash the URL
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url)
|
||||
new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest()
|
||||
|
||||
# Get current provider
|
||||
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
|
||||
# Check if URL is actually different
|
||||
if new_server_url_hash == provider.server_url_hash:
|
||||
# URL hasn't changed, but still return the encrypted data
|
||||
return ServerUrlValidationResult(
|
||||
needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash
|
||||
)
|
||||
|
||||
# Perform validation by attempting to connect
|
||||
reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider)
|
||||
return ServerUrlValidationResult(
|
||||
needs_validation=True,
|
||||
validation_passed=True,
|
||||
reconnect_result=reconnect_result,
|
||||
encrypted_server_url=encrypted_server_url,
|
||||
server_url_hash=new_server_url_hash,
|
||||
)
|
||||
|
||||
def _build_tool_provider_response(
|
||||
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
|
||||
) -> ToolProviderApiEntity:
|
||||
"""Build API response for tool provider."""
|
||||
user = db_provider.load_user()
|
||||
response = provider_entity.to_api_response(
|
||||
user_name=user.name if user else None,
|
||||
)
|
||||
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, tools)
|
||||
response["plugin_unique_identifier"] = provider_entity.provider_id
|
||||
return ToolProviderApiEntity(**response)
|
||||
|
||||
def _handle_integrity_error(
|
||||
self, error: IntegrityError, name: str, server_url: str, server_identifier: str
|
||||
) -> None:
|
||||
"""Handle database integrity errors with user-friendly messages."""
|
||||
error_msg = str(error.orig)
|
||||
if "unique_mcp_provider_name" in error_msg:
|
||||
raise ValueError(f"MCP tool {name} already exists")
|
||||
if "unique_mcp_provider_server_url" in error_msg:
|
||||
raise ValueError(f"MCP tool {server_url} already exists")
|
||||
if "unique_mcp_provider_server_identifier" in error_msg:
|
||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||
raise
|
||||
|
||||
def _is_valid_url(self, url: str) -> bool:
|
||||
"""Validate URL format."""
|
||||
if not url:
|
||||
return False
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
def _update_optional_fields(self, mcp_provider: MCPToolProvider, configuration: MCPConfiguration) -> None:
|
||||
"""Update optional configuration fields using setattr for cleaner code."""
|
||||
field_mapping = {"timeout": configuration.timeout, "sse_read_timeout": configuration.sse_read_timeout}
|
||||
|
||||
for field, value in field_mapping.items():
|
||||
if value is not None:
|
||||
setattr(mcp_provider, field, value)
|
||||
|
||||
def _process_headers(self, headers: dict[str, str], mcp_provider: MCPToolProvider, tenant_id: str) -> str | None:
|
||||
"""Process headers update, handling empty dict to clear headers."""
|
||||
if not headers:
|
||||
return None
|
||||
|
||||
# Merge with existing headers to preserve masked values
|
||||
final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider)
|
||||
return self._prepare_encrypted_dict(final_headers, tenant_id)
|
||||
|
||||
def _process_credentials(
|
||||
self, authentication: MCPAuthentication, mcp_provider: MCPToolProvider, tenant_id: str
|
||||
) -> str:
|
||||
"""Process credentials update, handling masked values."""
|
||||
# Merge with existing credentials
|
||||
final_client_id, final_client_secret = self._merge_credentials_with_masked(
|
||||
authentication.client_id, authentication.client_secret, mcp_provider
|
||||
)
|
||||
|
||||
# Build and encrypt
|
||||
return self._build_and_encrypt_credentials(final_client_id, final_client_secret, tenant_id)
|
||||
|
||||
def _merge_headers_with_masked(
|
||||
self, incoming_headers: dict[str, str], mcp_provider: MCPToolProvider
|
||||
) -> dict[str, str]:
|
||||
"""Merge incoming headers with existing ones, preserving unchanged masked values.
|
||||
|
||||
Args:
|
||||
incoming_headers: Headers from frontend (may contain masked values)
|
||||
mcp_provider: The MCP provider instance
|
||||
|
||||
Returns:
|
||||
Final headers dict with proper values (original for unchanged masked, new for changed)
|
||||
"""
|
||||
mcp_provider_entity = mcp_provider.to_entity()
|
||||
existing_decrypted = mcp_provider_entity.decrypt_headers()
|
||||
existing_masked = mcp_provider_entity.masked_headers()
|
||||
|
||||
return {
|
||||
key: (str(existing_decrypted[key]) if key in existing_masked and value == existing_masked[key] else value)
|
||||
for key, value in incoming_headers.items()
|
||||
if key in existing_decrypted or value != existing_masked.get(key)
|
||||
}
|
||||
|
||||
def _merge_credentials_with_masked(
|
||||
self,
|
||||
client_id: str,
|
||||
client_secret: str | None,
|
||||
mcp_provider: MCPToolProvider,
|
||||
) -> tuple[
|
||||
str,
|
||||
str | None,
|
||||
]:
|
||||
"""Merge incoming credentials with existing ones, preserving unchanged masked values.
|
||||
|
||||
Args:
|
||||
client_id: Client ID from frontend (may be masked)
|
||||
client_secret: Client secret from frontend (may be masked)
|
||||
mcp_provider: The MCP provider instance
|
||||
|
||||
Returns:
|
||||
Tuple of (final_client_id, final_client_secret)
|
||||
"""
|
||||
mcp_provider_entity = mcp_provider.to_entity()
|
||||
existing_decrypted = mcp_provider_entity.decrypt_credentials()
|
||||
existing_masked = mcp_provider_entity.masked_credentials()
|
||||
|
||||
# Check if client_id is masked and unchanged
|
||||
final_client_id = client_id
|
||||
if existing_masked.get("client_id") and client_id == existing_masked["client_id"]:
|
||||
# Use existing decrypted value
|
||||
final_client_id = existing_decrypted.get("client_id", client_id)
|
||||
|
||||
# Check if client_secret is masked and unchanged
|
||||
final_client_secret = client_secret
|
||||
if existing_masked.get("client_secret") and client_secret == existing_masked["client_secret"]:
|
||||
# Use existing decrypted value
|
||||
final_client_secret = existing_decrypted.get("client_secret", client_secret)
|
||||
|
||||
return final_client_id, final_client_secret
|
||||
|
||||
def _build_and_encrypt_credentials(self, client_id: str, client_secret: str | None, tenant_id: str) -> str:
|
||||
"""Build credentials and encrypt sensitive fields."""
|
||||
# Create a flat structure with all credential data
|
||||
credentials_data = {
|
||||
"client_id": client_id,
|
||||
"client_name": CLIENT_NAME,
|
||||
"is_dynamic_registration": False,
|
||||
}
|
||||
secret_fields = []
|
||||
if client_secret is not None:
|
||||
credentials_data["encrypted_client_secret"] = client_secret
|
||||
secret_fields = ["encrypted_client_secret"]
|
||||
client_info = self._encrypt_dict_fields(credentials_data, secret_fields, tenant_id)
|
||||
return json.dumps({"client_information": client_info})
|
||||
8
dify/api/services/tools/tool_labels_service.py
Normal file
8
dify/api/services/tools/tool_labels_service.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from core.tools.entities.tool_entities import ToolLabel
|
||||
from core.tools.entities.values import default_tool_labels
|
||||
|
||||
|
||||
class ToolLabelsService:
|
||||
@classmethod
|
||||
def list_tool_labels(cls) -> list[ToolLabel]:
|
||||
return default_tool_labels
|
||||
26
dify/api/services/tools/tools_manage_service.py
Normal file
26
dify/api/services/tools/tools_manage_service.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import logging
|
||||
|
||||
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolCommonService:
|
||||
@staticmethod
|
||||
def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral | None = None):
|
||||
"""
|
||||
list tool providers
|
||||
|
||||
:return: the list of tool providers
|
||||
"""
|
||||
providers = ToolManager.list_providers_from_api(user_id, tenant_id, typ)
|
||||
|
||||
# add icon
|
||||
for provider in providers:
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
|
||||
|
||||
result = [provider.to_dict() for provider in providers]
|
||||
|
||||
return result
|
||||
475
dify/api/services/tools/tools_transform_service.py
Normal file
475
dify/api/services/tools/tools_transform_service.py
Normal file
@@ -0,0 +1,475 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import ValidationError
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.mcp.types import Tool as MCPTool
|
||||
from core.plugin.entities.plugin_daemon import CredentialType, PluginDatasourceProviderEntity
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolTransformService:
|
||||
@classmethod
|
||||
def get_tool_provider_icon_url(
|
||||
cls, provider_type: str, provider_name: str, icon: str | Mapping[str, str]
|
||||
) -> str | Mapping[str, str]:
|
||||
"""
|
||||
get tool provider icon url
|
||||
"""
|
||||
url_prefix = (
|
||||
URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "tool-provider"
|
||||
)
|
||||
|
||||
if provider_type == ToolProviderType.BUILT_IN:
|
||||
return str(url_prefix / "builtin" / provider_name / "icon")
|
||||
elif provider_type in {ToolProviderType.API, ToolProviderType.WORKFLOW}:
|
||||
try:
|
||||
if isinstance(icon, str):
|
||||
return json.loads(icon)
|
||||
return icon
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
elif provider_type == ToolProviderType.MCP:
|
||||
return icon
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]):
|
||||
"""
|
||||
repack provider
|
||||
|
||||
:param tenant_id: the tenant id
|
||||
:param provider: the provider dict
|
||||
"""
|
||||
if isinstance(provider, dict) and "icon" in provider:
|
||||
provider["icon"] = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
|
||||
)
|
||||
elif isinstance(provider, ToolProviderApiEntity):
|
||||
if provider.plugin_id:
|
||||
if isinstance(provider.icon, str):
|
||||
provider.icon = PluginService.get_plugin_icon_url(tenant_id=tenant_id, filename=provider.icon)
|
||||
if isinstance(provider.icon_dark, str) and provider.icon_dark:
|
||||
provider.icon_dark = PluginService.get_plugin_icon_url(
|
||||
tenant_id=tenant_id, filename=provider.icon_dark
|
||||
)
|
||||
else:
|
||||
provider.icon = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
|
||||
)
|
||||
if provider.icon_dark:
|
||||
provider.icon_dark = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon_dark
|
||||
)
|
||||
elif isinstance(provider, PluginDatasourceProviderEntity):
|
||||
if provider.plugin_id:
|
||||
if isinstance(provider.declaration.identity.icon, str):
|
||||
provider.declaration.identity.icon = PluginService.get_plugin_icon_url(
|
||||
tenant_id=tenant_id, filename=provider.declaration.identity.icon
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def builtin_provider_to_user_provider(
|
||||
cls,
|
||||
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
|
||||
db_provider: BuiltinToolProvider | None,
|
||||
decrypt_credentials: bool = True,
|
||||
) -> ToolProviderApiEntity:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
result = ToolProviderApiEntity(
|
||||
id=provider_controller.entity.identity.name,
|
||||
author=provider_controller.entity.identity.author,
|
||||
name=provider_controller.entity.identity.name,
|
||||
description=provider_controller.entity.identity.description,
|
||||
icon=provider_controller.entity.identity.icon,
|
||||
icon_dark=provider_controller.entity.identity.icon_dark or "",
|
||||
label=provider_controller.entity.identity.label,
|
||||
type=ToolProviderType.BUILT_IN,
|
||||
masked_credentials={},
|
||||
is_team_authorization=False,
|
||||
plugin_id=None,
|
||||
tools=[],
|
||||
labels=provider_controller.tool_labels,
|
||||
)
|
||||
|
||||
if isinstance(provider_controller, PluginToolProviderController):
|
||||
result.plugin_id = provider_controller.plugin_id
|
||||
result.plugin_unique_identifier = provider_controller.plugin_unique_identifier
|
||||
|
||||
# get credentials schema
|
||||
schema = {
|
||||
x.to_basic_provider_config().name: x
|
||||
for x in provider_controller.get_credentials_schema_by_type(
|
||||
CredentialType.of(db_provider.credential_type) if db_provider else CredentialType.API_KEY
|
||||
)
|
||||
}
|
||||
|
||||
masked_creds = {}
|
||||
for name in schema:
|
||||
masked_creds[name] = ""
|
||||
result.masked_credentials = masked_creds
|
||||
|
||||
# check if the provider need credentials
|
||||
if not provider_controller.need_credentials:
|
||||
result.is_team_authorization = True
|
||||
result.allow_delete = False
|
||||
elif db_provider:
|
||||
result.is_team_authorization = True
|
||||
|
||||
if decrypt_credentials:
|
||||
credentials = db_provider.credentials
|
||||
if not db_provider.tenant_id:
|
||||
raise ValueError(f"Required tenant_id is missing for BuiltinToolProvider with id {db_provider.id}")
|
||||
# init tool configuration
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
for x in provider_controller.get_credentials_schema_by_type(
|
||||
CredentialType.of(db_provider.credential_type)
|
||||
)
|
||||
],
|
||||
cache=ToolProviderCredentialsCache(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
provider=db_provider.provider,
|
||||
credential_id=db_provider.id,
|
||||
),
|
||||
)
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = encrypter.decrypt(data=credentials)
|
||||
masked_credentials = encrypter.mask_plugin_credentials(data=decrypted_credentials)
|
||||
|
||||
result.masked_credentials = masked_credentials
|
||||
result.original_credentials = decrypted_credentials
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def api_provider_to_controller(
|
||||
db_provider: ApiToolProvider,
|
||||
) -> ApiToolProviderController:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
# package tool provider controller
|
||||
auth_type = ApiProviderAuthType.NONE
|
||||
credentials_auth_type = db_provider.credentials.get("auth_type")
|
||||
if credentials_auth_type in ("api_key_header", "api_key"): # backward compatibility
|
||||
auth_type = ApiProviderAuthType.API_KEY_HEADER
|
||||
elif credentials_auth_type == "api_key_query":
|
||||
auth_type = ApiProviderAuthType.API_KEY_QUERY
|
||||
|
||||
controller = ApiToolProviderController.from_db(
|
||||
db_provider=db_provider,
|
||||
auth_type=auth_type,
|
||||
)
|
||||
|
||||
return controller
|
||||
|
||||
@staticmethod
|
||||
def workflow_provider_to_controller(db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
|
||||
"""
|
||||
convert provider controller to provider
|
||||
"""
|
||||
return WorkflowToolProviderController.from_db(db_provider)
|
||||
|
||||
@staticmethod
|
||||
def workflow_provider_to_user_provider(
|
||||
provider_controller: WorkflowToolProviderController, labels: list[str] | None = None
|
||||
):
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
return ToolProviderApiEntity(
|
||||
id=provider_controller.provider_id,
|
||||
author=provider_controller.entity.identity.author,
|
||||
name=provider_controller.entity.identity.name,
|
||||
description=provider_controller.entity.identity.description,
|
||||
icon=provider_controller.entity.identity.icon,
|
||||
icon_dark=provider_controller.entity.identity.icon_dark or "",
|
||||
label=provider_controller.entity.identity.label,
|
||||
type=ToolProviderType.WORKFLOW,
|
||||
masked_credentials={},
|
||||
is_team_authorization=True,
|
||||
plugin_id=None,
|
||||
plugin_unique_identifier=None,
|
||||
tools=[],
|
||||
labels=labels or [],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def mcp_provider_to_user_provider(
|
||||
db_provider: MCPToolProvider,
|
||||
for_list: bool = False,
|
||||
user_name: str | None = None,
|
||||
include_sensitive: bool = True,
|
||||
) -> ToolProviderApiEntity:
|
||||
from core.entities.mcp_provider import MCPConfiguration
|
||||
|
||||
# Use provided user_name to avoid N+1 query, fallback to load_user() if not provided
|
||||
if user_name is None:
|
||||
user = db_provider.load_user()
|
||||
user_name = user.name if user else None
|
||||
|
||||
# Convert to entity and use its API response method
|
||||
provider_entity = db_provider.to_entity()
|
||||
|
||||
response = provider_entity.to_api_response(user_name=user_name, include_sensitive=include_sensitive)
|
||||
try:
|
||||
mcp_tools = [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
|
||||
except (ValidationError, json.JSONDecodeError):
|
||||
mcp_tools = []
|
||||
# Add additional fields specific to the transform
|
||||
response["id"] = db_provider.server_identifier if not for_list else db_provider.id
|
||||
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, mcp_tools, user_name=user_name)
|
||||
response["server_identifier"] = db_provider.server_identifier
|
||||
|
||||
# Convert configuration dict to MCPConfiguration object
|
||||
if "configuration" in response and isinstance(response["configuration"], dict):
|
||||
response["configuration"] = MCPConfiguration(
|
||||
timeout=float(response["configuration"]["timeout"]),
|
||||
sse_read_timeout=float(response["configuration"]["sse_read_timeout"]),
|
||||
)
|
||||
|
||||
return ToolProviderApiEntity(**response)
|
||||
|
||||
@staticmethod
|
||||
def mcp_tool_to_user_tool(
|
||||
mcp_provider: MCPToolProvider, tools: list[MCPTool], user_name: str | None = None
|
||||
) -> list[ToolApiEntity]:
|
||||
# Use provided user_name to avoid N+1 query, fallback to load_user() if not provided
|
||||
if user_name is None:
|
||||
user = mcp_provider.load_user()
|
||||
user_name = user.name if user else "Anonymous"
|
||||
|
||||
return [
|
||||
ToolApiEntity(
|
||||
author=user_name or "Anonymous",
|
||||
name=tool.name,
|
||||
label=I18nObject(en_US=tool.name, zh_Hans=tool.name),
|
||||
description=I18nObject(en_US=tool.description or "", zh_Hans=tool.description or ""),
|
||||
parameters=ToolTransformService.convert_mcp_schema_to_parameter(tool.inputSchema),
|
||||
labels=[],
|
||||
output_schema=tool.outputSchema or {},
|
||||
)
|
||||
for tool in tools
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def api_provider_to_user_provider(
|
||||
cls,
|
||||
provider_controller: ApiToolProviderController,
|
||||
db_provider: ApiToolProvider,
|
||||
decrypt_credentials: bool = True,
|
||||
labels: list[str] | None = None,
|
||||
) -> ToolProviderApiEntity:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
username = "Anonymous"
|
||||
if db_provider.user is None:
|
||||
raise ValueError(f"user is None for api provider {db_provider.id}")
|
||||
try:
|
||||
user = db_provider.user
|
||||
if not user:
|
||||
raise ValueError("user not found")
|
||||
|
||||
username = user.name
|
||||
except Exception:
|
||||
logger.exception("failed to get user name for api provider %s", db_provider.id)
|
||||
# add provider into providers
|
||||
credentials = db_provider.credentials
|
||||
result = ToolProviderApiEntity(
|
||||
id=db_provider.id,
|
||||
author=username,
|
||||
name=db_provider.name,
|
||||
description=I18nObject(
|
||||
en_US=db_provider.description,
|
||||
zh_Hans=db_provider.description,
|
||||
),
|
||||
icon=db_provider.icon,
|
||||
label=I18nObject(
|
||||
en_US=db_provider.name,
|
||||
zh_Hans=db_provider.name,
|
||||
),
|
||||
type=ToolProviderType.API,
|
||||
plugin_id=None,
|
||||
plugin_unique_identifier=None,
|
||||
masked_credentials={},
|
||||
is_team_authorization=True,
|
||||
tools=[],
|
||||
labels=labels or [],
|
||||
)
|
||||
|
||||
if decrypt_credentials:
|
||||
# init tool configuration
|
||||
encrypter, _ = create_tool_provider_encrypter(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
controller=provider_controller,
|
||||
)
|
||||
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = encrypter.decrypt(data=credentials)
|
||||
masked_credentials = encrypter.mask_plugin_credentials(data=decrypted_credentials)
|
||||
|
||||
result.masked_credentials = masked_credentials
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def convert_tool_entity_to_api_entity(
|
||||
tool: ApiToolBundle | WorkflowTool | Tool,
|
||||
tenant_id: str,
|
||||
labels: list[str] | None = None,
|
||||
) -> ToolApiEntity:
|
||||
"""
|
||||
convert tool to user tool
|
||||
"""
|
||||
if isinstance(tool, Tool):
|
||||
# fork tool runtime
|
||||
tool = tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
credentials={},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
# get tool parameters
|
||||
base_parameters = tool.entity.parameters or []
|
||||
# get tool runtime parameters
|
||||
runtime_parameters = tool.get_runtime_parameters()
|
||||
|
||||
# merge parameters using a functional approach to avoid type issues
|
||||
merged_parameters: list[ToolParameter] = []
|
||||
|
||||
# create a mapping of runtime parameters for quick lookup
|
||||
runtime_param_map = {(rp.name, rp.form): rp for rp in runtime_parameters}
|
||||
|
||||
# process base parameters, replacing with runtime versions if they exist
|
||||
for base_param in base_parameters:
|
||||
key = (base_param.name, base_param.form)
|
||||
if key in runtime_param_map:
|
||||
merged_parameters.append(runtime_param_map[key])
|
||||
else:
|
||||
merged_parameters.append(base_param)
|
||||
|
||||
# add any runtime parameters that weren't in base parameters
|
||||
for runtime_parameter in runtime_parameters:
|
||||
if runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
# check if this parameter is already in merged_parameters
|
||||
already_exists = any(
|
||||
p.name == runtime_parameter.name and p.form == runtime_parameter.form for p in merged_parameters
|
||||
)
|
||||
if not already_exists:
|
||||
merged_parameters.append(runtime_parameter)
|
||||
|
||||
return ToolApiEntity(
|
||||
author=tool.entity.identity.author,
|
||||
name=tool.entity.identity.name,
|
||||
label=tool.entity.identity.label,
|
||||
description=tool.entity.description.human if tool.entity.description else I18nObject(en_US=""),
|
||||
output_schema=tool.entity.output_schema,
|
||||
parameters=merged_parameters,
|
||||
labels=labels or [],
|
||||
)
|
||||
else:
|
||||
assert tool.operation_id
|
||||
return ToolApiEntity(
|
||||
author=tool.author,
|
||||
name=tool.operation_id or "",
|
||||
label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
|
||||
description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
|
||||
parameters=tool.parameters,
|
||||
labels=labels or [],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def convert_builtin_provider_to_credential_entity(
|
||||
provider: BuiltinToolProvider, credentials: dict
|
||||
) -> ToolProviderCredentialApiEntity:
|
||||
return ToolProviderCredentialApiEntity(
|
||||
id=provider.id,
|
||||
name=provider.name,
|
||||
provider=provider.provider,
|
||||
credential_type=CredentialType.of(provider.credential_type),
|
||||
is_default=provider.is_default,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def convert_mcp_schema_to_parameter(schema: dict[str, Any]) -> list["ToolParameter"]:
|
||||
"""
|
||||
Convert MCP JSON schema to tool parameters
|
||||
|
||||
:param schema: JSON schema dictionary
|
||||
:return: list of ToolParameter instances
|
||||
"""
|
||||
|
||||
def create_parameter(
|
||||
name: str, description: str, param_type: str, required: bool, input_schema: dict[str, Any] | None = None
|
||||
) -> ToolParameter:
|
||||
"""Create a ToolParameter instance with given attributes"""
|
||||
input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {}
|
||||
return ToolParameter(
|
||||
name=name,
|
||||
llm_description=description,
|
||||
label=I18nObject(en_US=name),
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
required=required,
|
||||
type=ToolParameter.ToolParameterType(param_type),
|
||||
human_description=I18nObject(en_US=description),
|
||||
**input_schema_dict,
|
||||
)
|
||||
|
||||
def process_properties(
|
||||
props: dict[str, dict[str, Any]], required: list[str], prefix: str = ""
|
||||
) -> list[ToolParameter]:
|
||||
"""Process properties recursively"""
|
||||
TYPE_MAPPING = {"integer": "number", "float": "number"}
|
||||
COMPLEX_TYPES = ["array", "object"]
|
||||
|
||||
parameters = []
|
||||
for name, prop in props.items():
|
||||
current_description = prop.get("description", "")
|
||||
prop_type = prop.get("type", "string")
|
||||
|
||||
if isinstance(prop_type, list):
|
||||
prop_type = prop_type[0]
|
||||
if prop_type in TYPE_MAPPING:
|
||||
prop_type = TYPE_MAPPING[prop_type]
|
||||
input_schema = prop if prop_type in COMPLEX_TYPES else None
|
||||
parameters.append(
|
||||
create_parameter(name, current_description, prop_type, name in required, input_schema)
|
||||
)
|
||||
|
||||
return parameters
|
||||
|
||||
if schema.get("type") == "object" and "properties" in schema:
|
||||
return process_properties(schema["properties"], schema.get("required", []))
|
||||
return []
|
||||
340
dify/api/services/tools/workflow_tools_manage_service.py
Normal file
340
dify/api/services/tools/workflow_tools_manage_service.py
Normal file
@@ -0,0 +1,340 @@
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.model import App
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.workflow import Workflow
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
class WorkflowToolManageService:
|
||||
"""
|
||||
Service class for managing workflow tools.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create_workflow_tool(
|
||||
*,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
workflow_app_id: str,
|
||||
name: str,
|
||||
label: str,
|
||||
icon: dict,
|
||||
description: str,
|
||||
parameters: list[Mapping[str, Any]],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] | None = None,
|
||||
):
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
# check if the name is unique
|
||||
existing_workflow_tool_provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
# name or app_id
|
||||
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_workflow_tool_provider is not None:
|
||||
raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")
|
||||
|
||||
app: App | None = db.session.query(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
|
||||
|
||||
if app is None:
|
||||
raise ValueError(f"App {workflow_app_id} not found")
|
||||
|
||||
workflow: Workflow | None = app.workflow
|
||||
if workflow is None:
|
||||
raise ValueError(f"Workflow not found for app {workflow_app_id}")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
workflow_tool_provider = WorkflowToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=workflow_app_id,
|
||||
name=name,
|
||||
label=label,
|
||||
icon=json.dumps(icon),
|
||||
description=description,
|
||||
parameter_configuration=json.dumps(parameters),
|
||||
privacy_policy=privacy_policy,
|
||||
version=workflow.version,
|
||||
)
|
||||
session.add(workflow_tool_provider)
|
||||
|
||||
try:
|
||||
WorkflowToolProviderController.from_db(workflow_tool_provider)
|
||||
except Exception as e:
|
||||
raise ValueError(str(e))
|
||||
|
||||
if labels is not None:
|
||||
ToolLabelManager.update_tool_labels(
|
||||
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def update_workflow_tool(
|
||||
cls,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
workflow_tool_id: str,
|
||||
name: str,
|
||||
label: str,
|
||||
icon: dict,
|
||||
description: str,
|
||||
parameters: list[Mapping[str, Any]],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Update a workflow tool.
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:param workflow_tool_id: workflow tool id
|
||||
:param name: name
|
||||
:param label: label
|
||||
:param icon: icon
|
||||
:param description: description
|
||||
:param parameters: parameters
|
||||
:param privacy_policy: privacy policy
|
||||
:param labels: labels
|
||||
:return: the updated tool
|
||||
"""
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
# check if the name is unique
|
||||
existing_workflow_tool_provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.name == name,
|
||||
WorkflowToolProvider.id != workflow_tool_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_workflow_tool_provider is not None:
|
||||
raise ValueError(f"Tool with name {name} already exists")
|
||||
|
||||
workflow_tool_provider: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if workflow_tool_provider is None:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
app: App | None = (
|
||||
db.session.query(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
|
||||
)
|
||||
|
||||
if app is None:
|
||||
raise ValueError(f"App {workflow_tool_provider.app_id} not found")
|
||||
|
||||
workflow: Workflow | None = app.workflow
|
||||
if workflow is None:
|
||||
raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")
|
||||
|
||||
workflow_tool_provider.name = name
|
||||
workflow_tool_provider.label = label
|
||||
workflow_tool_provider.icon = json.dumps(icon)
|
||||
workflow_tool_provider.description = description
|
||||
workflow_tool_provider.parameter_configuration = json.dumps(parameters)
|
||||
workflow_tool_provider.privacy_policy = privacy_policy
|
||||
workflow_tool_provider.version = workflow.version
|
||||
workflow_tool_provider.updated_at = datetime.now()
|
||||
|
||||
try:
|
||||
WorkflowToolProviderController.from_db(workflow_tool_provider)
|
||||
except Exception as e:
|
||||
raise ValueError(str(e))
|
||||
|
||||
db.session.commit()
|
||||
|
||||
if labels is not None:
|
||||
ToolLabelManager.update_tool_labels(
|
||||
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||
"""
|
||||
List workflow tools.
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:return: the list of tools
|
||||
"""
|
||||
db_tools = db.session.scalars(
|
||||
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
|
||||
).all()
|
||||
|
||||
tools: list[WorkflowToolProviderController] = []
|
||||
for provider in db_tools:
|
||||
try:
|
||||
tools.append(ToolTransformService.workflow_provider_to_controller(provider))
|
||||
except Exception:
|
||||
# skip deleted tools
|
||||
pass
|
||||
|
||||
labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)])
|
||||
|
||||
result = []
|
||||
|
||||
for tool in tools:
|
||||
user_tool_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=tool, labels=labels.get(tool.provider_id, [])
|
||||
)
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_tool_provider)
|
||||
user_tool_provider.tools = [
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tool=tool.get_tools(tenant_id)[0],
|
||||
labels=labels.get(tool.provider_id, []),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
]
|
||||
result.append(user_tool_provider)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
|
||||
"""
|
||||
Delete a workflow tool.
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:param workflow_tool_id: the workflow tool id
|
||||
"""
|
||||
db.session.query(WorkflowToolProvider).where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
|
||||
).delete()
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
|
||||
"""
|
||||
Get a workflow tool.
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:param workflow_tool_id: the workflow tool id
|
||||
:return: the tool
|
||||
"""
|
||||
db_tool: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
)
|
||||
return cls._get_workflow_tool(tenant_id, db_tool)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str):
|
||||
"""
|
||||
Get a workflow tool.
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:param workflow_app_id: the workflow app id
|
||||
:return: the tool
|
||||
"""
|
||||
db_tool: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
|
||||
.first()
|
||||
)
|
||||
return cls._get_workflow_tool(tenant_id, db_tool)
|
||||
|
||||
@classmethod
|
||||
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None):
|
||||
"""
|
||||
Get a workflow tool.
|
||||
:db_tool: the database tool
|
||||
:return: the tool
|
||||
"""
|
||||
if db_tool is None:
|
||||
raise ValueError("Tool not found")
|
||||
|
||||
workflow_app: App | None = (
|
||||
db.session.query(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first()
|
||||
)
|
||||
|
||||
if workflow_app is None:
|
||||
raise ValueError(f"App {db_tool.app_id} not found")
|
||||
|
||||
workflow = workflow_app.workflow
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
|
||||
if len(workflow_tools) == 0:
|
||||
raise ValueError(f"Tool {db_tool.id} not found")
|
||||
|
||||
return {
|
||||
"name": db_tool.name,
|
||||
"label": db_tool.label,
|
||||
"workflow_tool_id": db_tool.id,
|
||||
"workflow_app_id": db_tool.app_id,
|
||||
"icon": json.loads(db_tool.icon),
|
||||
"description": db_tool.description,
|
||||
"parameters": jsonable_encoder(db_tool.parameter_configurations),
|
||||
"tool": ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tool=tool.get_tools(db_tool.tenant_id)[0],
|
||||
labels=ToolLabelManager.get_tool_labels(tool),
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
"synced": workflow.version == db_tool.version,
|
||||
"privacy_policy": db_tool.privacy_policy,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[ToolApiEntity]:
|
||||
"""
|
||||
List workflow tool provider tools.
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:param workflow_tool_id: the workflow tool id
|
||||
:return: the list of tools
|
||||
"""
|
||||
db_tool: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if db_tool is None:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
|
||||
if len(workflow_tools) == 0:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
return [
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tool=tool.get_tools(db_tool.tenant_id)[0],
|
||||
labels=ToolLabelManager.get_tool_labels(tool),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
]
|
||||
Reference in New Issue
Block a user