dify
This commit is contained in:
156
dify/api/core/tools/mcp_tool/provider.py
Normal file
156
dify/api/core/tools/mcp_tool/provider.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from typing import Any, Self
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.mcp.types import Tool as RemoteMCPTool
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolEntity,
|
||||
ToolIdentity,
|
||||
ToolProviderEntityWithPlugin,
|
||||
ToolProviderIdentity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.mcp_tool.tool import MCPTool
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
class MCPToolProviderController(ToolProviderController):
|
||||
def __init__(
|
||||
self,
|
||||
entity: ToolProviderEntityWithPlugin,
|
||||
provider_id: str,
|
||||
tenant_id: str,
|
||||
server_url: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
):
|
||||
super().__init__(entity)
|
||||
self.entity: ToolProviderEntityWithPlugin = entity
|
||||
self.tenant_id = tenant_id
|
||||
self.provider_id = provider_id
|
||||
self.server_url = server_url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return ToolProviderType.MCP
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_provider: MCPToolProvider) -> Self:
|
||||
"""
|
||||
from db provider
|
||||
"""
|
||||
# Convert to entity first
|
||||
provider_entity = db_provider.to_entity()
|
||||
return cls.from_entity(provider_entity)
|
||||
|
||||
@classmethod
|
||||
def from_entity(cls, entity: MCPProviderEntity) -> Self:
|
||||
"""
|
||||
create a MCPToolProviderController from a MCPProviderEntity
|
||||
"""
|
||||
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in entity.tools]
|
||||
|
||||
tools = [
|
||||
ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
author="Anonymous", # Tool level author is not stored
|
||||
name=remote_mcp_tool.name,
|
||||
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
|
||||
provider=entity.provider_id,
|
||||
icon=entity.icon if isinstance(entity.icon, str) else "",
|
||||
),
|
||||
parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
|
||||
description=ToolDescription(
|
||||
human=I18nObject(
|
||||
en_US=remote_mcp_tool.description or "", zh_Hans=remote_mcp_tool.description or ""
|
||||
),
|
||||
llm=remote_mcp_tool.description or "",
|
||||
),
|
||||
output_schema=remote_mcp_tool.outputSchema or {},
|
||||
has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0,
|
||||
)
|
||||
for remote_mcp_tool in remote_mcp_tools
|
||||
]
|
||||
if not entity.icon:
|
||||
raise ValueError("Database provider icon is required")
|
||||
return cls(
|
||||
entity=ToolProviderEntityWithPlugin(
|
||||
identity=ToolProviderIdentity(
|
||||
author="Anonymous", # Provider level author is not stored in entity
|
||||
name=entity.name,
|
||||
label=I18nObject(en_US=entity.name, zh_Hans=entity.name),
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
icon=entity.icon if isinstance(entity.icon, str) else "",
|
||||
),
|
||||
plugin_id=None,
|
||||
credentials_schema=[],
|
||||
tools=tools,
|
||||
),
|
||||
provider_id=entity.provider_id,
|
||||
tenant_id=entity.tenant_id,
|
||||
server_url=entity.server_url,
|
||||
headers=entity.headers,
|
||||
timeout=entity.timeout,
|
||||
sse_read_timeout=entity.sse_read_timeout,
|
||||
)
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_tool(self, tool_name: str) -> MCPTool:
|
||||
"""
|
||||
return tool with given name
|
||||
"""
|
||||
tool_entity = next(
|
||||
(tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name), None
|
||||
)
|
||||
|
||||
if not tool_entity:
|
||||
raise ValueError(f"Tool with name {tool_name} not found")
|
||||
|
||||
return MCPTool(
|
||||
entity=tool_entity,
|
||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
server_url=self.server_url,
|
||||
provider_id=self.provider_id,
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
|
||||
def get_tools(self) -> list[MCPTool]:
|
||||
"""
|
||||
get all tools
|
||||
"""
|
||||
return [
|
||||
MCPTool(
|
||||
entity=tool_entity,
|
||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
server_url=self.server_url,
|
||||
provider_id=self.provider_id,
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
for tool_entity in self.entity.tools
|
||||
]
|
||||
174
dify/api/core/tools/mcp_tool/tool.py
Normal file
174
dify/api/core/tools/mcp_tool/tool.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPConnectionError
|
||||
from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
def __init__(
|
||||
self,
|
||||
entity: ToolEntity,
|
||||
runtime: ToolRuntime,
|
||||
tenant_id: str,
|
||||
icon: str,
|
||||
server_url: str,
|
||||
provider_id: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
):
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.server_url = server_url
|
||||
self.provider_id = provider_id
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.MCP
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
result = self.invoke_remote_mcp_tool(tool_parameters)
|
||||
# handle dify tool output
|
||||
for content in result.content:
|
||||
if isinstance(content, TextContent):
|
||||
yield from self._process_text_content(content)
|
||||
elif isinstance(content, ImageContent):
|
||||
yield self._process_image_content(content)
|
||||
elif isinstance(content, AudioContent):
|
||||
yield self._process_audio_content(content)
|
||||
else:
|
||||
logger.warning("Unsupported content type=%s", type(content))
|
||||
|
||||
# handle MCP structured output
|
||||
if self.entity.output_schema and result.structuredContent:
|
||||
for k, v in result.structuredContent.items():
|
||||
yield self.create_variable_message(k, v)
|
||||
|
||||
def _process_text_content(self, content: TextContent) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""Process text content and yield appropriate messages."""
|
||||
# Check if content looks like JSON before attempting to parse
|
||||
text = content.text.strip()
|
||||
if text and text[0] in ("{", "[") and text[-1] in ("}", "]"):
|
||||
try:
|
||||
content_json = json.loads(text)
|
||||
yield from self._process_json_content(content_json)
|
||||
return
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# If not JSON or parsing failed, treat as plain text
|
||||
yield self.create_text_message(content.text)
|
||||
|
||||
def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""Process JSON content based on its type."""
|
||||
if isinstance(content_json, dict):
|
||||
yield self.create_json_message(content_json)
|
||||
elif isinstance(content_json, list):
|
||||
yield from self._process_json_list(content_json)
|
||||
else:
|
||||
# For primitive types (str, int, bool, etc.), convert to string
|
||||
yield self.create_text_message(str(content_json))
|
||||
|
||||
def _process_json_list(self, json_list: list) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""Process a list of JSON items."""
|
||||
if any(not isinstance(item, dict) for item in json_list):
|
||||
# If the list contains any non-dict item, treat the entire list as a text message.
|
||||
yield self.create_text_message(str(json_list))
|
||||
return
|
||||
|
||||
# Otherwise, process each dictionary as a separate JSON message.
|
||||
for item in json_list:
|
||||
yield self.create_json_message(item)
|
||||
|
||||
def _process_image_content(self, content: ImageContent) -> ToolInvokeMessage:
|
||||
"""Process image content and return a blob message."""
|
||||
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
|
||||
|
||||
def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage:
|
||||
"""Process audio content and return a blob message."""
|
||||
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
|
||||
return MCPTool(
|
||||
entity=self.entity,
|
||||
runtime=runtime,
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.icon,
|
||||
server_url=self.server_url,
|
||||
provider_id=self.provider_id,
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
|
||||
def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
in mcp tool invoke, if the parameter is empty, it will be set to None
|
||||
"""
|
||||
return {
|
||||
key: value
|
||||
for key, value in parameter.items()
|
||||
if value is not None and not (isinstance(value, str) and value.strip() == "")
|
||||
}
|
||||
|
||||
def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any]) -> CallToolResult:
|
||||
headers = self.headers.copy() if self.headers else {}
|
||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
# Step 1: Load provider entity and credentials in a short-lived session
|
||||
# This minimizes database connection hold time
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
|
||||
|
||||
# Decrypt and prepare all credentials before closing session
|
||||
server_url = provider_entity.decrypt_server_url()
|
||||
headers = provider_entity.decrypt_headers()
|
||||
|
||||
# Try to get existing token and add to headers
|
||||
if not headers:
|
||||
tokens = provider_entity.retrieve_tokens()
|
||||
if tokens and tokens.access_token:
|
||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||
|
||||
# Step 2: Session is now closed, perform network operations without holding database connection
|
||||
# MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed
|
||||
try:
|
||||
with MCPClientWithAuthRetry(
|
||||
server_url=server_url,
|
||||
headers=headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
) as mcp_client:
|
||||
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPConnectionError as e:
|
||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||
Reference in New Issue
Block a user