dify
This commit is contained in:
115
dify/api/core/mcp/mcp_client.py
Normal file
115
dify/api/core/mcp/mcp_client.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from contextlib import AbstractContextManager, ExitStack
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from core.mcp.client.sse_client import sse_client
|
||||
from core.mcp.client.streamable_client import streamablehttp_client
|
||||
from core.mcp.error import MCPConnectionError
|
||||
from core.mcp.session.client_session import ClientSession
|
||||
from core.mcp.types import CallToolResult, Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPClient:
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
):
|
||||
self.server_url = server_url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
|
||||
# Initialize session and client objects
|
||||
self._session: ClientSession | None = None
|
||||
self._exit_stack = ExitStack()
|
||||
self._initialized = False
|
||||
|
||||
def __enter__(self):
|
||||
self._initialize()
|
||||
self._initialized = True
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: type | None, exc_value: BaseException | None, traceback: TracebackType | None):
|
||||
self.cleanup()
|
||||
|
||||
def _initialize(
|
||||
self,
|
||||
):
|
||||
"""Initialize the client with fallback to SSE if streamable connection fails"""
|
||||
connection_methods: dict[str, Callable[..., AbstractContextManager[Any]]] = {
|
||||
"mcp": streamablehttp_client,
|
||||
"sse": sse_client,
|
||||
}
|
||||
|
||||
parsed_url = urlparse(self.server_url)
|
||||
path = parsed_url.path or ""
|
||||
method_name = path.rstrip("/").split("/")[-1] if path else ""
|
||||
if method_name in connection_methods:
|
||||
client_factory = connection_methods[method_name]
|
||||
self.connect_server(client_factory, method_name)
|
||||
else:
|
||||
try:
|
||||
logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name)
|
||||
self.connect_server(sse_client, "sse")
|
||||
except MCPConnectionError:
|
||||
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
|
||||
self.connect_server(streamablehttp_client, "mcp")
|
||||
|
||||
def connect_server(self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str) -> None:
|
||||
"""
|
||||
Connect to the MCP server using streamable http or sse.
|
||||
Default to streamable http.
|
||||
Args:
|
||||
client_factory: The client factory to use(streamablehttp_client or sse_client).
|
||||
method_name: The method name to use(mcp or sse).
|
||||
"""
|
||||
streams_context = client_factory(
|
||||
url=self.server_url,
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
|
||||
# Use exit_stack to manage context managers properly
|
||||
if method_name == "mcp":
|
||||
read_stream, write_stream, _ = self._exit_stack.enter_context(streams_context)
|
||||
streams = (read_stream, write_stream)
|
||||
else: # sse_client
|
||||
streams = self._exit_stack.enter_context(streams_context)
|
||||
|
||||
session_context = ClientSession(*streams)
|
||||
self._session = self._exit_stack.enter_context(session_context)
|
||||
self._session.initialize()
|
||||
|
||||
def list_tools(self) -> list[Tool]:
|
||||
"""List available tools from the MCP server"""
|
||||
if not self._session:
|
||||
raise ValueError("Session not initialized.")
|
||||
response = self._session.list_tools()
|
||||
return response.tools
|
||||
|
||||
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
|
||||
"""Call a tool"""
|
||||
if not self._session:
|
||||
raise ValueError("Session not initialized.")
|
||||
return self._session.call_tool(tool_name, tool_args)
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up resources"""
|
||||
try:
|
||||
# ExitStack will handle proper cleanup of all managed context managers
|
||||
self._exit_stack.close()
|
||||
except Exception as e:
|
||||
logger.exception("Error during cleanup")
|
||||
raise ValueError(f"Error during cleanup: {e}")
|
||||
finally:
|
||||
self._session = None
|
||||
self._initialized = False
|
||||
Reference in New Issue
Block a user