dify
This commit is contained in:
0
dify/api/core/mcp/__init__.py
Normal file
0
dify/api/core/mcp/__init__.py
Normal file
705
dify/api/core/mcp/auth/auth_flow.py
Normal file
705
dify/api/core/mcp/auth/auth_flow.py
Normal file
@@ -0,0 +1,705 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import urllib.parse
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
from httpx import RequestError
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
|
||||
from core.helper import ssrf_proxy
|
||||
from core.mcp.entities import AuthAction, AuthActionType, AuthResult, OAuthCallbackState
|
||||
from core.mcp.error import MCPRefreshTokenError
|
||||
from core.mcp.types import (
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthMetadata,
|
||||
OAuthTokens,
|
||||
ProtectedResourceMetadata,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
|
||||
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
|
||||
|
||||
|
||||
def generate_pkce_challenge() -> tuple[str, str]:
|
||||
"""Generate PKCE challenge and verifier."""
|
||||
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
|
||||
code_verifier = code_verifier.replace("=", "").replace("+", "-").replace("/", "_")
|
||||
|
||||
code_challenge_hash = hashlib.sha256(code_verifier.encode("utf-8")).digest()
|
||||
code_challenge = base64.urlsafe_b64encode(code_challenge_hash).decode("utf-8")
|
||||
code_challenge = code_challenge.replace("=", "").replace("+", "-").replace("/", "_")
|
||||
|
||||
return code_verifier, code_challenge
|
||||
|
||||
|
||||
def build_protected_resource_metadata_discovery_urls(
|
||||
www_auth_resource_metadata_url: str | None, server_url: str
|
||||
) -> list[str]:
|
||||
"""
|
||||
Build a list of URLs to try for Protected Resource Metadata discovery.
|
||||
|
||||
Per SEP-985, supports fallback when discovery fails at one URL.
|
||||
"""
|
||||
urls = []
|
||||
|
||||
# First priority: URL from WWW-Authenticate header
|
||||
if www_auth_resource_metadata_url:
|
||||
urls.append(www_auth_resource_metadata_url)
|
||||
|
||||
# Fallback: construct from server URL
|
||||
parsed = urlparse(server_url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
|
||||
if fallback_url not in urls:
|
||||
urls.append(fallback_url)
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]:
|
||||
"""
|
||||
Build a list of URLs to try for OAuth Authorization Server Metadata discovery.
|
||||
|
||||
Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.
|
||||
|
||||
Per RFC 8414 section 3:
|
||||
- If issuer has no path: https://example.com/.well-known/oauth-authorization-server
|
||||
- If issuer has path: https://example.com/.well-known/oauth-authorization-server{path}
|
||||
|
||||
Example:
|
||||
- issuer: https://example.com/oauth
|
||||
- metadata: https://example.com/.well-known/oauth-authorization-server/oauth
|
||||
"""
|
||||
urls = []
|
||||
base_url = auth_server_url or server_url
|
||||
|
||||
parsed = urlparse(base_url)
|
||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||
path = parsed.path.rstrip("/") # Remove trailing slash
|
||||
|
||||
# Try OpenID Connect discovery first (more common)
|
||||
urls.append(urljoin(base + "/", ".well-known/openid-configuration"))
|
||||
|
||||
# OAuth 2.0 Authorization Server Metadata (RFC 8414)
|
||||
# Include the path component if present in the issuer URL
|
||||
if path:
|
||||
urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}"))
|
||||
else:
|
||||
urls.append(urljoin(base, ".well-known/oauth-authorization-server"))
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
def discover_protected_resource_metadata(
|
||||
prm_url: str | None, server_url: str, protocol_version: str | None = None
|
||||
) -> ProtectedResourceMetadata | None:
|
||||
"""Discover OAuth 2.0 Protected Resource Metadata (RFC 9470)."""
|
||||
urls = build_protected_resource_metadata_discovery_urls(prm_url, server_url)
|
||||
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
|
||||
|
||||
for url in urls:
|
||||
try:
|
||||
response = ssrf_proxy.get(url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
return ProtectedResourceMetadata.model_validate(response.json())
|
||||
elif response.status_code == 404:
|
||||
continue # Try next URL
|
||||
except (RequestError, ValidationError):
|
||||
continue # Try next URL
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def discover_oauth_authorization_server_metadata(
|
||||
auth_server_url: str | None, server_url: str, protocol_version: str | None = None
|
||||
) -> OAuthMetadata | None:
|
||||
"""Discover OAuth 2.0 Authorization Server Metadata (RFC 8414)."""
|
||||
urls = build_oauth_authorization_server_metadata_discovery_urls(auth_server_url, server_url)
|
||||
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
|
||||
|
||||
for url in urls:
|
||||
try:
|
||||
response = ssrf_proxy.get(url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
elif response.status_code == 404:
|
||||
continue # Try next URL
|
||||
except (RequestError, ValidationError):
|
||||
continue # Try next URL
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_effective_scope(
|
||||
scope_from_www_auth: str | None,
|
||||
prm: ProtectedResourceMetadata | None,
|
||||
asm: OAuthMetadata | None,
|
||||
client_scope: str | None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Determine effective scope using priority-based selection strategy.
|
||||
|
||||
Priority order:
|
||||
1. WWW-Authenticate header scope (server explicit requirement)
|
||||
2. Protected Resource Metadata scopes
|
||||
3. OAuth Authorization Server Metadata scopes
|
||||
4. Client configured scope
|
||||
"""
|
||||
if scope_from_www_auth:
|
||||
return scope_from_www_auth
|
||||
|
||||
if prm and prm.scopes_supported:
|
||||
return " ".join(prm.scopes_supported)
|
||||
|
||||
if asm and asm.scopes_supported:
|
||||
return " ".join(asm.scopes_supported)
|
||||
|
||||
return client_scope
|
||||
|
||||
|
||||
def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
|
||||
"""Create a secure state parameter by storing state data in Redis and returning a random state key."""
|
||||
# Generate a secure random state key
|
||||
state_key = secrets.token_urlsafe(32)
|
||||
|
||||
# Store the state data in Redis with expiration
|
||||
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
|
||||
redis_client.setex(redis_key, OAUTH_STATE_EXPIRY_SECONDS, state_data.model_dump_json())
|
||||
|
||||
return state_key
|
||||
|
||||
|
||||
def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
|
||||
"""Retrieve and decode OAuth state data from Redis using the state key, then delete it."""
|
||||
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
|
||||
|
||||
# Get state data from Redis
|
||||
state_data = redis_client.get(redis_key)
|
||||
|
||||
if not state_data:
|
||||
raise ValueError("State parameter has expired or does not exist")
|
||||
|
||||
# Delete the state data from Redis immediately after retrieval to prevent reuse
|
||||
redis_client.delete(redis_key)
|
||||
|
||||
try:
|
||||
# Parse and validate the state data
|
||||
oauth_state = OAuthCallbackState.model_validate_json(state_data)
|
||||
|
||||
return oauth_state
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"Invalid state parameter: {str(e)}")
|
||||
|
||||
|
||||
def handle_callback(state_key: str, authorization_code: str) -> tuple[OAuthCallbackState, OAuthTokens]:
|
||||
"""
|
||||
Handle the callback from the OAuth provider.
|
||||
|
||||
Returns:
|
||||
A tuple of (callback_state, tokens) that can be used by the caller to save data.
|
||||
"""
|
||||
# Retrieve state data from Redis (state is automatically deleted after retrieval)
|
||||
full_state_data = _retrieve_redis_state(state_key)
|
||||
|
||||
tokens = exchange_authorization(
|
||||
full_state_data.server_url,
|
||||
full_state_data.metadata,
|
||||
full_state_data.client_information,
|
||||
authorization_code,
|
||||
full_state_data.code_verifier,
|
||||
full_state_data.redirect_uri,
|
||||
)
|
||||
|
||||
return full_state_data, tokens
|
||||
|
||||
|
||||
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
|
||||
"""Check if the server supports OAuth 2.0 Resource Discovery."""
|
||||
b_scheme, b_netloc, _, _, b_query, b_fragment = urlparse(server_url, "", True)
|
||||
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource"
|
||||
if b_query:
|
||||
url_for_resource_discovery += f"?{b_query}"
|
||||
if b_fragment:
|
||||
url_for_resource_discovery += f"#{b_fragment}"
|
||||
try:
|
||||
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
|
||||
response = ssrf_proxy.get(url_for_resource_discovery, headers=headers)
|
||||
if 200 <= response.status_code < 300:
|
||||
body = response.json()
|
||||
# Support both singular and plural forms
|
||||
if body.get("authorization_servers"):
|
||||
return True, body["authorization_servers"][0]
|
||||
elif body.get("authorization_server_url"):
|
||||
return True, body["authorization_server_url"][0]
|
||||
else:
|
||||
return False, ""
|
||||
return False, ""
|
||||
except RequestError:
|
||||
# Not support resource discovery, fall back to well-known OAuth metadata
|
||||
return False, ""
|
||||
|
||||
|
||||
def discover_oauth_metadata(
|
||||
server_url: str,
|
||||
resource_metadata_url: str | None = None,
|
||||
scope_hint: str | None = None,
|
||||
protocol_version: str | None = None,
|
||||
) -> tuple[OAuthMetadata | None, ProtectedResourceMetadata | None, str | None]:
|
||||
"""
|
||||
Discover OAuth metadata using RFC 8414/9470 standards.
|
||||
|
||||
Args:
|
||||
server_url: The MCP server URL
|
||||
resource_metadata_url: Protected Resource Metadata URL from WWW-Authenticate header
|
||||
scope_hint: Scope hint from WWW-Authenticate header
|
||||
protocol_version: MCP protocol version
|
||||
|
||||
Returns:
|
||||
(oauth_metadata, protected_resource_metadata, scope_hint)
|
||||
"""
|
||||
# Discover Protected Resource Metadata
|
||||
prm = discover_protected_resource_metadata(resource_metadata_url, server_url, protocol_version)
|
||||
|
||||
# Get authorization server URL from PRM or use server URL
|
||||
auth_server_url = None
|
||||
if prm and prm.authorization_servers:
|
||||
auth_server_url = prm.authorization_servers[0]
|
||||
|
||||
# Discover OAuth Authorization Server Metadata
|
||||
asm = discover_oauth_authorization_server_metadata(auth_server_url, server_url, protocol_version)
|
||||
|
||||
return asm, prm, scope_hint
|
||||
|
||||
|
||||
def start_authorization(
|
||||
server_url: str,
|
||||
metadata: OAuthMetadata | None,
|
||||
client_information: OAuthClientInformation,
|
||||
redirect_url: str,
|
||||
provider_id: str,
|
||||
tenant_id: str,
|
||||
scope: str | None = None,
|
||||
) -> tuple[str, str]:
|
||||
"""Begins the authorization flow with secure Redis state storage."""
|
||||
response_type = "code"
|
||||
code_challenge_method = "S256"
|
||||
|
||||
if metadata:
|
||||
authorization_url = metadata.authorization_endpoint
|
||||
if response_type not in metadata.response_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
|
||||
else:
|
||||
authorization_url = urljoin(server_url, "/authorize")
|
||||
|
||||
code_verifier, code_challenge = generate_pkce_challenge()
|
||||
|
||||
# Prepare state data with all necessary information
|
||||
state_data = OAuthCallbackState(
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
server_url=server_url,
|
||||
metadata=metadata,
|
||||
client_information=client_information,
|
||||
code_verifier=code_verifier,
|
||||
redirect_uri=redirect_url,
|
||||
)
|
||||
|
||||
# Store state data in Redis and generate secure state key
|
||||
state_key = _create_secure_redis_state(state_data)
|
||||
|
||||
params = {
|
||||
"response_type": response_type,
|
||||
"client_id": client_information.client_id,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": code_challenge_method,
|
||||
"redirect_uri": redirect_url,
|
||||
"state": state_key,
|
||||
}
|
||||
|
||||
# Add scope if provided
|
||||
if scope:
|
||||
params["scope"] = scope
|
||||
|
||||
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
|
||||
return authorization_url, code_verifier
|
||||
|
||||
|
||||
def _parse_token_response(response: httpx.Response) -> OAuthTokens:
|
||||
"""
|
||||
Parse OAuth token response supporting both JSON and form-urlencoded formats.
|
||||
|
||||
Per RFC 6749 Section 5.1, the standard format is JSON.
|
||||
However, some legacy OAuth providers (e.g., early GitHub OAuth Apps) return
|
||||
application/x-www-form-urlencoded format for backwards compatibility.
|
||||
|
||||
Args:
|
||||
response: The HTTP response from token endpoint
|
||||
|
||||
Returns:
|
||||
Parsed OAuth tokens
|
||||
|
||||
Raises:
|
||||
ValueError: If response cannot be parsed
|
||||
"""
|
||||
content_type = response.headers.get("content-type", "").lower()
|
||||
|
||||
if "application/json" in content_type:
|
||||
# Standard OAuth 2.0 JSON response (RFC 6749)
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
elif "application/x-www-form-urlencoded" in content_type:
|
||||
# Legacy form-urlencoded response (non-standard but used by some providers)
|
||||
token_data = dict(urllib.parse.parse_qsl(response.text))
|
||||
return OAuthTokens.model_validate(token_data)
|
||||
else:
|
||||
# No content-type or unknown - try JSON first, fallback to form-urlencoded
|
||||
try:
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
except (ValidationError, json.JSONDecodeError):
|
||||
token_data = dict(urllib.parse.parse_qsl(response.text))
|
||||
return OAuthTokens.model_validate(token_data)
|
||||
|
||||
|
||||
def exchange_authorization(
|
||||
server_url: str,
|
||||
metadata: OAuthMetadata | None,
|
||||
client_information: OAuthClientInformation,
|
||||
authorization_code: str,
|
||||
code_verifier: str,
|
||||
redirect_uri: str,
|
||||
) -> OAuthTokens:
|
||||
"""Exchanges an authorization code for an access token."""
|
||||
grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
|
||||
else:
|
||||
token_url = urljoin(server_url, "/token")
|
||||
|
||||
params = {
|
||||
"grant_type": grant_type,
|
||||
"client_id": client_information.client_id,
|
||||
"code": authorization_code,
|
||||
"code_verifier": code_verifier,
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
|
||||
if client_information.client_secret:
|
||||
params["client_secret"] = client_information.client_secret
|
||||
|
||||
response = ssrf_proxy.post(token_url, data=params)
|
||||
if not response.is_success:
|
||||
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
|
||||
return _parse_token_response(response)
|
||||
|
||||
|
||||
def refresh_authorization(
|
||||
server_url: str,
|
||||
metadata: OAuthMetadata | None,
|
||||
client_information: OAuthClientInformation,
|
||||
refresh_token: str,
|
||||
) -> OAuthTokens:
|
||||
"""Exchange a refresh token for an updated access token."""
|
||||
grant_type = MCPSupportGrantType.REFRESH_TOKEN.value
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
|
||||
else:
|
||||
token_url = urljoin(server_url, "/token")
|
||||
|
||||
params = {
|
||||
"grant_type": grant_type,
|
||||
"client_id": client_information.client_id,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
|
||||
if client_information.client_secret:
|
||||
params["client_secret"] = client_information.client_secret
|
||||
try:
|
||||
response = ssrf_proxy.post(token_url, data=params)
|
||||
except ssrf_proxy.MaxRetriesExceededError as e:
|
||||
raise MCPRefreshTokenError(e) from e
|
||||
if not response.is_success:
|
||||
raise MCPRefreshTokenError(response.text)
|
||||
return _parse_token_response(response)
|
||||
|
||||
|
||||
def client_credentials_flow(
|
||||
server_url: str,
|
||||
metadata: OAuthMetadata | None,
|
||||
client_information: OAuthClientInformation,
|
||||
scope: str | None = None,
|
||||
) -> OAuthTokens:
|
||||
"""Execute Client Credentials Flow to get access token."""
|
||||
grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
|
||||
else:
|
||||
token_url = urljoin(server_url, "/token")
|
||||
|
||||
# Support both Basic Auth and body parameters for client authentication
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
data = {"grant_type": grant_type}
|
||||
|
||||
if scope:
|
||||
data["scope"] = scope
|
||||
|
||||
# If client_secret is provided, use Basic Auth (preferred method)
|
||||
if client_information.client_secret:
|
||||
credentials = f"{client_information.client_id}:{client_information.client_secret}"
|
||||
encoded_credentials = base64.b64encode(credentials.encode()).decode()
|
||||
headers["Authorization"] = f"Basic {encoded_credentials}"
|
||||
else:
|
||||
# Fall back to including credentials in the body
|
||||
data["client_id"] = client_information.client_id
|
||||
if client_information.client_secret:
|
||||
data["client_secret"] = client_information.client_secret
|
||||
|
||||
response = ssrf_proxy.post(token_url, headers=headers, data=data)
|
||||
if not response.is_success:
|
||||
raise ValueError(
|
||||
f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}"
|
||||
)
|
||||
|
||||
return _parse_token_response(response)
|
||||
|
||||
|
||||
def register_client(
|
||||
server_url: str,
|
||||
metadata: OAuthMetadata | None,
|
||||
client_metadata: OAuthClientMetadata,
|
||||
) -> OAuthClientInformationFull:
|
||||
"""Performs OAuth 2.0 Dynamic Client Registration."""
|
||||
if metadata:
|
||||
if not metadata.registration_endpoint:
|
||||
raise ValueError("Incompatible auth server: does not support dynamic client registration")
|
||||
registration_url = metadata.registration_endpoint
|
||||
else:
|
||||
registration_url = urljoin(server_url, "/register")
|
||||
|
||||
response = ssrf_proxy.post(
|
||||
registration_url,
|
||||
json=client_metadata.model_dump(),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
if not response.is_success:
|
||||
response.raise_for_status()
|
||||
return OAuthClientInformationFull.model_validate(response.json())
|
||||
|
||||
|
||||
def auth(
|
||||
provider: MCPProviderEntity,
|
||||
authorization_code: str | None = None,
|
||||
state_param: str | None = None,
|
||||
resource_metadata_url: str | None = None,
|
||||
scope_hint: str | None = None,
|
||||
) -> AuthResult:
|
||||
"""
|
||||
Orchestrates the full auth flow with a server using secure Redis state storage.
|
||||
|
||||
This function performs only network operations and returns actions that need
|
||||
to be performed by the caller (such as saving data to database).
|
||||
|
||||
Args:
|
||||
provider: The MCP provider entity
|
||||
authorization_code: Optional authorization code from OAuth callback
|
||||
state_param: Optional state parameter from OAuth callback
|
||||
resource_metadata_url: Optional Protected Resource Metadata URL from WWW-Authenticate
|
||||
scope_hint: Optional scope hint from WWW-Authenticate header
|
||||
|
||||
Returns:
|
||||
AuthResult containing actions to be performed and response data
|
||||
"""
|
||||
actions: list[AuthAction] = []
|
||||
server_url = provider.decrypt_server_url()
|
||||
|
||||
# Discover OAuth metadata using RFC 8414/9470 standards
|
||||
server_metadata, prm, scope_from_www_auth = discover_oauth_metadata(
|
||||
server_url, resource_metadata_url, scope_hint, LATEST_PROTOCOL_VERSION
|
||||
)
|
||||
|
||||
client_metadata = provider.client_metadata
|
||||
provider_id = provider.id
|
||||
tenant_id = provider.tenant_id
|
||||
client_information = provider.retrieve_client_information()
|
||||
redirect_url = provider.redirect_url
|
||||
credentials = provider.decrypt_credentials()
|
||||
|
||||
# Determine grant type based on server metadata
|
||||
if not server_metadata:
|
||||
raise ValueError("Failed to discover OAuth metadata from server")
|
||||
|
||||
supported_grant_types = server_metadata.grant_types_supported or []
|
||||
|
||||
# Convert to lowercase for comparison
|
||||
supported_grant_types_lower = [gt.lower() for gt in supported_grant_types]
|
||||
|
||||
# Determine which grant type to use
|
||||
effective_grant_type = None
|
||||
if MCPSupportGrantType.AUTHORIZATION_CODE.value in supported_grant_types_lower:
|
||||
effective_grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
|
||||
else:
|
||||
effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
|
||||
|
||||
# Determine effective scope using priority-based strategy
|
||||
effective_scope = get_effective_scope(scope_from_www_auth, prm, server_metadata, credentials.get("scope"))
|
||||
|
||||
if not client_information:
|
||||
if authorization_code is not None:
|
||||
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
|
||||
|
||||
# For client credentials flow, we don't need to register client dynamically
|
||||
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
|
||||
# Client should provide client_id and client_secret directly
|
||||
raise ValueError("Client credentials flow requires client_id and client_secret to be provided")
|
||||
|
||||
try:
|
||||
full_information = register_client(server_url, server_metadata, client_metadata)
|
||||
except RequestError as e:
|
||||
raise ValueError(f"Could not register OAuth client: {e}")
|
||||
|
||||
# Return action to save client information
|
||||
actions.append(
|
||||
AuthAction(
|
||||
action_type=AuthActionType.SAVE_CLIENT_INFO,
|
||||
data={"client_information": full_information.model_dump()},
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
client_information = full_information
|
||||
|
||||
# Handle client credentials flow
|
||||
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
|
||||
# Direct token request without user interaction
|
||||
try:
|
||||
tokens = client_credentials_flow(
|
||||
server_url,
|
||||
server_metadata,
|
||||
client_information,
|
||||
effective_scope,
|
||||
)
|
||||
|
||||
# Return action to save tokens and grant type
|
||||
token_data = tokens.model_dump()
|
||||
token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value
|
||||
|
||||
actions.append(
|
||||
AuthAction(
|
||||
action_type=AuthActionType.SAVE_TOKENS,
|
||||
data=token_data,
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
return AuthResult(actions=actions, response={"result": "success"})
|
||||
except (RequestError, ValueError, KeyError) as e:
|
||||
# RequestError: HTTP request failed
|
||||
# ValueError: Invalid response data
|
||||
# KeyError: Missing required fields in response
|
||||
raise ValueError(f"Client credentials flow failed: {e}")
|
||||
|
||||
# Exchange authorization code for tokens (Authorization Code flow)
|
||||
if authorization_code is not None:
|
||||
if not state_param:
|
||||
raise ValueError("State parameter is required when exchanging authorization code")
|
||||
|
||||
try:
|
||||
# Retrieve state data from Redis using state key
|
||||
full_state_data = _retrieve_redis_state(state_param)
|
||||
|
||||
code_verifier = full_state_data.code_verifier
|
||||
redirect_uri = full_state_data.redirect_uri
|
||||
|
||||
if not code_verifier or not redirect_uri:
|
||||
raise ValueError("Missing code_verifier or redirect_uri in state data")
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise ValueError(f"Invalid state parameter: {e}")
|
||||
|
||||
tokens = exchange_authorization(
|
||||
server_url,
|
||||
server_metadata,
|
||||
client_information,
|
||||
authorization_code,
|
||||
code_verifier,
|
||||
redirect_uri,
|
||||
)
|
||||
|
||||
# Return action to save tokens
|
||||
actions.append(
|
||||
AuthAction(
|
||||
action_type=AuthActionType.SAVE_TOKENS,
|
||||
data=tokens.model_dump(),
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
return AuthResult(actions=actions, response={"result": "success"})
|
||||
|
||||
provider_tokens = provider.retrieve_tokens()
|
||||
|
||||
# Handle token refresh or new authorization
|
||||
if provider_tokens and provider_tokens.refresh_token:
|
||||
try:
|
||||
new_tokens = refresh_authorization(
|
||||
server_url, server_metadata, client_information, provider_tokens.refresh_token
|
||||
)
|
||||
|
||||
# Return action to save new tokens
|
||||
actions.append(
|
||||
AuthAction(
|
||||
action_type=AuthActionType.SAVE_TOKENS,
|
||||
data=new_tokens.model_dump(),
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
return AuthResult(actions=actions, response={"result": "success"})
|
||||
except (RequestError, ValueError, KeyError) as e:
|
||||
# RequestError: HTTP request failed
|
||||
# ValueError: Invalid response data
|
||||
# KeyError: Missing required fields in response
|
||||
raise ValueError(f"Could not refresh OAuth tokens: {e}")
|
||||
|
||||
# Start new authorization flow (only for authorization code flow)
|
||||
authorization_url, code_verifier = start_authorization(
|
||||
server_url,
|
||||
server_metadata,
|
||||
client_information,
|
||||
redirect_url,
|
||||
provider_id,
|
||||
tenant_id,
|
||||
effective_scope,
|
||||
)
|
||||
|
||||
# Return action to save code verifier
|
||||
actions.append(
|
||||
AuthAction(
|
||||
action_type=AuthActionType.SAVE_CODE_VERIFIER,
|
||||
data={"code_verifier": code_verifier},
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
return AuthResult(actions=actions, response={"authorization_url": authorization_url})
|
||||
197
dify/api/core/mcp/auth_client.py
Normal file
197
dify/api/core/mcp/auth_client.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
MCP Client with Authentication Retry Support
|
||||
|
||||
This module provides an enhanced MCPClient that automatically handles
|
||||
authentication failures and retries operations after refreshing tokens.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.mcp.error import MCPAuthError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.mcp.types import CallToolResult, Tool
|
||||
from extensions.ext_database import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPClientWithAuthRetry(MCPClient):
|
||||
"""
|
||||
An enhanced MCPClient that provides automatic authentication retry.
|
||||
|
||||
This class extends MCPClient and intercepts MCPAuthError exceptions
|
||||
to refresh authentication before retrying failed operations.
|
||||
|
||||
Note: This class uses lazy session creation - database sessions are only
|
||||
created when authentication retry is actually needed, not on every request.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
provider_entity: MCPProviderEntity | None = None,
|
||||
authorization_code: str | None = None,
|
||||
by_server_id: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the MCP client with auth retry capability.
|
||||
|
||||
Args:
|
||||
server_url: The MCP server URL
|
||||
headers: Optional headers for requests
|
||||
timeout: Request timeout
|
||||
sse_read_timeout: SSE read timeout
|
||||
provider_entity: Provider entity for authentication
|
||||
authorization_code: Optional authorization code for initial auth
|
||||
by_server_id: Whether to look up provider by server ID
|
||||
"""
|
||||
super().__init__(server_url, headers, timeout, sse_read_timeout)
|
||||
|
||||
self.provider_entity = provider_entity
|
||||
self.authorization_code = authorization_code
|
||||
self.by_server_id = by_server_id
|
||||
self._has_retried = False
|
||||
|
||||
def _handle_auth_error(self, error: MCPAuthError) -> None:
|
||||
"""
|
||||
Handle authentication error by refreshing tokens.
|
||||
|
||||
This method creates a short-lived database session only when authentication
|
||||
retry is needed, minimizing database connection hold time.
|
||||
|
||||
Args:
|
||||
error: The authentication error
|
||||
|
||||
Raises:
|
||||
MCPAuthError: If authentication fails or max retries reached
|
||||
"""
|
||||
if not self.provider_entity:
|
||||
raise error
|
||||
if self._has_retried:
|
||||
raise error
|
||||
|
||||
self._has_retried = True
|
||||
|
||||
try:
|
||||
# Create a temporary session only for auth retry
|
||||
# This session is short-lived and only exists during the auth operation
|
||||
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
with Session(db.engine) as session, session.begin():
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
|
||||
# Perform authentication using the service's auth method
|
||||
# Extract OAuth metadata hints from the error
|
||||
mcp_service.auth_with_actions(
|
||||
self.provider_entity,
|
||||
self.authorization_code,
|
||||
resource_metadata_url=error.resource_metadata_url,
|
||||
scope_hint=error.scope_hint,
|
||||
)
|
||||
|
||||
# Retrieve new tokens
|
||||
self.provider_entity = mcp_service.get_provider_entity(
|
||||
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
|
||||
)
|
||||
|
||||
# Session is closed here, before we update headers
|
||||
token = self.provider_entity.retrieve_tokens()
|
||||
if not token:
|
||||
raise MCPAuthError("Authentication failed - no token received")
|
||||
|
||||
# Update headers with new token
|
||||
self.headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
|
||||
|
||||
# Clear authorization code after first use
|
||||
self.authorization_code = None
|
||||
|
||||
except MCPAuthError:
|
||||
# Re-raise MCPAuthError as is
|
||||
raise
|
||||
except Exception as e:
|
||||
# Catch all exceptions during auth retry
|
||||
logger.exception("Authentication retry failed")
|
||||
raise MCPAuthError(f"Authentication retry failed: {e}") from e
|
||||
|
||||
def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any:
|
||||
"""
|
||||
Execute a function with authentication retry logic.
|
||||
|
||||
Args:
|
||||
func: The function to execute
|
||||
*args: Positional arguments for the function
|
||||
**kwargs: Keyword arguments for the function
|
||||
|
||||
Returns:
|
||||
The result of the function call
|
||||
|
||||
Raises:
|
||||
MCPAuthError: If authentication fails after retries
|
||||
Any other exceptions from the function
|
||||
"""
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except MCPAuthError as e:
|
||||
self._handle_auth_error(e)
|
||||
|
||||
# Re-initialize the connection with new headers
|
||||
if self._initialized:
|
||||
# Clean up existing connection
|
||||
self._exit_stack.close()
|
||||
self._session = None
|
||||
self._initialized = False
|
||||
|
||||
# Re-initialize with new headers
|
||||
self._initialize()
|
||||
self._initialized = True
|
||||
|
||||
return func(*args, **kwargs)
|
||||
finally:
|
||||
# Reset retry flag after operation completes
|
||||
self._has_retried = False
|
||||
|
||||
def __enter__(self):
|
||||
"""Enter the context manager with retry support."""
|
||||
|
||||
def initialize_with_retry():
|
||||
super(MCPClientWithAuthRetry, self).__enter__()
|
||||
return self
|
||||
|
||||
return self._execute_with_retry(initialize_with_retry)
|
||||
|
||||
def list_tools(self) -> list[Tool]:
|
||||
"""
|
||||
List available tools from the MCP server with auth retry.
|
||||
|
||||
Returns:
|
||||
List of available tools
|
||||
|
||||
Raises:
|
||||
MCPAuthError: If authentication fails after retries
|
||||
"""
|
||||
return self._execute_with_retry(super().list_tools)
|
||||
|
||||
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
|
||||
"""
|
||||
Invoke a tool on the MCP server with auth retry.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to invoke
|
||||
tool_args: Arguments for the tool
|
||||
|
||||
Returns:
|
||||
Result of the tool invocation
|
||||
|
||||
Raises:
|
||||
MCPAuthError: If authentication fails after retries
|
||||
"""
|
||||
return self._execute_with_retry(super().invoke_tool, tool_name, tool_args)
|
||||
0
dify/api/core/mcp/auth_client_comparison.md
Normal file
0
dify/api/core/mcp/auth_client_comparison.md
Normal file
360
dify/api/core/mcp/client/sse_client.py
Normal file
360
dify/api/core/mcp/client/sse_client.py
Normal file
@@ -0,0 +1,360 @@
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Generator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, TypeAlias, final
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
from httpx_sse import EventSource, ServerSentEvent
|
||||
from sseclient import SSEClient
|
||||
|
||||
from core.mcp import types
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.types import SessionMessage
|
||||
from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_QUEUE_READ_TIMEOUT = 3
|
||||
|
||||
|
||||
@final
|
||||
class _StatusReady:
|
||||
def __init__(self, endpoint_url: str):
|
||||
self.endpoint_url = endpoint_url
|
||||
|
||||
|
||||
@final
|
||||
class _StatusError:
|
||||
def __init__(self, exc: Exception):
|
||||
self.exc = exc
|
||||
|
||||
|
||||
# Type aliases for better readability
|
||||
ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
|
||||
WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
|
||||
StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
|
||||
|
||||
|
||||
class SSETransport:
|
||||
"""SSE client transport implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = 5.0,
|
||||
sse_read_timeout: float = 1 * 60,
|
||||
):
|
||||
"""Initialize the SSE transport.
|
||||
|
||||
Args:
|
||||
url: The SSE endpoint URL.
|
||||
headers: Optional headers to include in requests.
|
||||
timeout: HTTP timeout for regular operations.
|
||||
sse_read_timeout: Timeout for SSE read operations.
|
||||
"""
|
||||
self.url = url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
self.endpoint_url: str | None = None
|
||||
|
||||
def _validate_endpoint_url(self, endpoint_url: str) -> bool:
|
||||
"""Validate that the endpoint URL matches the connection origin.
|
||||
|
||||
Args:
|
||||
endpoint_url: The endpoint URL to validate.
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise.
|
||||
"""
|
||||
url_parsed = urlparse(self.url)
|
||||
endpoint_parsed = urlparse(endpoint_url)
|
||||
|
||||
return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme
|
||||
|
||||
def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue):
|
||||
"""Handle an 'endpoint' SSE event.
|
||||
|
||||
Args:
|
||||
sse_data: The SSE event data.
|
||||
status_queue: Queue to put status updates.
|
||||
"""
|
||||
endpoint_url = urljoin(self.url, sse_data)
|
||||
logger.info("Received endpoint URL: %s", endpoint_url)
|
||||
|
||||
if not self._validate_endpoint_url(endpoint_url):
|
||||
error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
|
||||
logger.error(error_msg)
|
||||
status_queue.put(_StatusError(ValueError(error_msg)))
|
||||
return
|
||||
|
||||
status_queue.put(_StatusReady(endpoint_url))
|
||||
|
||||
def _handle_message_event(self, sse_data: str, read_queue: ReadQueue):
|
||||
"""Handle a 'message' SSE event.
|
||||
|
||||
Args:
|
||||
sse_data: The SSE event data.
|
||||
read_queue: Queue to put parsed messages.
|
||||
"""
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(sse_data)
|
||||
logger.debug("Received server message: %s", message)
|
||||
session_message = SessionMessage(message)
|
||||
read_queue.put(session_message)
|
||||
except Exception as exc:
|
||||
logger.exception("Error parsing server message")
|
||||
read_queue.put(exc)
|
||||
|
||||
def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue):
|
||||
"""Handle a single SSE event.
|
||||
|
||||
Args:
|
||||
sse: The SSE event object.
|
||||
read_queue: Queue for message events.
|
||||
status_queue: Queue for status events.
|
||||
"""
|
||||
match sse.event:
|
||||
case "endpoint":
|
||||
self._handle_endpoint_event(sse.data, status_queue)
|
||||
case "message":
|
||||
self._handle_message_event(sse.data, read_queue)
|
||||
case _:
|
||||
logger.warning("Unknown SSE event: %s", sse.event)
|
||||
|
||||
def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue):
|
||||
"""Read and process SSE events.
|
||||
|
||||
Args:
|
||||
event_source: The SSE event source.
|
||||
read_queue: Queue to put received messages.
|
||||
status_queue: Queue to put status updates.
|
||||
"""
|
||||
try:
|
||||
for sse in event_source.iter_sse():
|
||||
self._handle_sse_event(sse, read_queue, status_queue)
|
||||
except httpx.ReadError as exc:
|
||||
logger.debug("SSE reader shutting down normally: %s", exc)
|
||||
except Exception as exc:
|
||||
read_queue.put(exc)
|
||||
finally:
|
||||
read_queue.put(None)
|
||||
|
||||
def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage):
|
||||
"""Send a single message to the server.
|
||||
|
||||
Args:
|
||||
client: HTTP client to use.
|
||||
endpoint_url: The endpoint URL to send to.
|
||||
message: The message to send.
|
||||
"""
|
||||
response = client.post(
|
||||
endpoint_url,
|
||||
json=message.message.model_dump(
|
||||
by_alias=True,
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.debug("Client message sent successfully: %s", response.status_code)
|
||||
|
||||
def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue):
|
||||
"""Handle writing messages to the server.
|
||||
|
||||
Args:
|
||||
client: HTTP client to use.
|
||||
endpoint_url: The endpoint URL to send messages to.
|
||||
write_queue: Queue to read messages from.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
||||
if message is None:
|
||||
break
|
||||
if isinstance(message, Exception):
|
||||
write_queue.put(message)
|
||||
continue
|
||||
|
||||
self._send_message(client, endpoint_url, message)
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
except httpx.ReadError as exc:
|
||||
logger.debug("Post writer shutting down normally: %s", exc)
|
||||
except Exception as exc:
|
||||
logger.exception("Error writing messages")
|
||||
write_queue.put(exc)
|
||||
finally:
|
||||
write_queue.put(None)
|
||||
|
||||
def _wait_for_endpoint(self, status_queue: StatusQueue) -> str:
|
||||
"""Wait for the endpoint URL from the status queue.
|
||||
|
||||
Args:
|
||||
status_queue: Queue to read status from.
|
||||
|
||||
Returns:
|
||||
The endpoint URL.
|
||||
|
||||
Raises:
|
||||
ValueError: If endpoint URL is not received or there's an error.
|
||||
"""
|
||||
try:
|
||||
status = status_queue.get(timeout=1)
|
||||
except queue.Empty:
|
||||
raise ValueError("failed to get endpoint URL")
|
||||
|
||||
if isinstance(status, _StatusReady):
|
||||
return status.endpoint_url
|
||||
elif isinstance(status, _StatusError):
|
||||
raise status.exc
|
||||
else:
|
||||
raise ValueError("failed to get endpoint URL")
|
||||
|
||||
def connect(
|
||||
self,
|
||||
executor: ThreadPoolExecutor,
|
||||
client: httpx.Client,
|
||||
event_source: EventSource,
|
||||
) -> tuple[ReadQueue, WriteQueue]:
|
||||
"""Establish connection and start worker threads.
|
||||
|
||||
Args:
|
||||
executor: Thread pool executor.
|
||||
client: HTTP client.
|
||||
event_source: SSE event source.
|
||||
|
||||
Returns:
|
||||
Tuple of (read_queue, write_queue).
|
||||
"""
|
||||
read_queue: ReadQueue = queue.Queue()
|
||||
write_queue: WriteQueue = queue.Queue()
|
||||
status_queue: StatusQueue = queue.Queue()
|
||||
|
||||
# Start SSE reader thread
|
||||
executor.submit(self.sse_reader, event_source, read_queue, status_queue)
|
||||
|
||||
# Wait for endpoint URL
|
||||
endpoint_url = self._wait_for_endpoint(status_queue)
|
||||
self.endpoint_url = endpoint_url
|
||||
|
||||
# Start post writer thread
|
||||
executor.submit(self.post_writer, client, endpoint_url, write_queue)
|
||||
|
||||
return read_queue, write_queue
|
||||
|
||||
|
||||
@contextmanager
|
||||
def sse_client(
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = 5.0,
|
||||
sse_read_timeout: float = 1 * 60,
|
||||
) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
|
||||
"""
|
||||
Client transport for SSE.
|
||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
|
||||
Args:
|
||||
url: The SSE endpoint URL.
|
||||
headers: Optional headers to include in requests.
|
||||
timeout: HTTP timeout for regular operations.
|
||||
sse_read_timeout: Timeout for SSE read operations.
|
||||
|
||||
Yields:
|
||||
Tuple of (read_queue, write_queue) for message communication.
|
||||
"""
|
||||
transport = SSETransport(url, headers, timeout, sse_read_timeout)
|
||||
|
||||
read_queue: ReadQueue | None = None
|
||||
write_queue: WriteQueue | None = None
|
||||
|
||||
executor = ThreadPoolExecutor()
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
|
||||
with ssrf_proxy_sse_connect(
|
||||
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
|
||||
read_queue, write_queue = transport.connect(executor, client, event_source)
|
||||
|
||||
yield read_queue, write_queue
|
||||
|
||||
except httpx.HTTPStatusError as exc:
|
||||
if exc.response.status_code == 401:
|
||||
raise MCPAuthError(response=exc.response)
|
||||
raise MCPConnectionError()
|
||||
except Exception:
|
||||
logger.exception("Error connecting to SSE endpoint")
|
||||
raise
|
||||
finally:
|
||||
# Clean up queues
|
||||
if read_queue:
|
||||
read_queue.put(None)
|
||||
if write_queue:
|
||||
write_queue.put(None)
|
||||
|
||||
# Shutdown executor without waiting to prevent hanging
|
||||
executor.shutdown(wait=False)
|
||||
|
||||
|
||||
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage):
|
||||
"""
|
||||
Send a message to the server using the provided HTTP client.
|
||||
|
||||
Args:
|
||||
http_client: The HTTP client to use for sending
|
||||
endpoint_url: The endpoint URL to send the message to
|
||||
session_message: The message to send
|
||||
"""
|
||||
try:
|
||||
response = http_client.post(
|
||||
endpoint_url,
|
||||
json=session_message.message.model_dump(
|
||||
by_alias=True,
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.debug("Client message sent successfully: %s", response.status_code)
|
||||
except Exception:
|
||||
logger.exception("Error sending message")
|
||||
raise
|
||||
|
||||
|
||||
def read_messages(
|
||||
sse_client: SSEClient,
|
||||
) -> Generator[SessionMessage | Exception, None, None]:
|
||||
"""
|
||||
Read messages from the SSE client.
|
||||
|
||||
Args:
|
||||
sse_client: The SSE client to read from
|
||||
|
||||
Yields:
|
||||
SessionMessage or Exception for each event received
|
||||
"""
|
||||
try:
|
||||
for sse in sse_client.events():
|
||||
if sse.event == "message":
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(sse.data)
|
||||
logger.debug("Received server message: %s", message)
|
||||
yield SessionMessage(message)
|
||||
except Exception as exc:
|
||||
logger.exception("Error parsing server message")
|
||||
yield exc
|
||||
else:
|
||||
logger.warning("Unknown SSE event: %s", sse.event)
|
||||
except Exception as exc:
|
||||
logger.exception("Error reading SSE messages")
|
||||
yield exc
|
||||
485
dify/api/core/mcp/client/streamable_client.py
Normal file
485
dify/api/core/mcp/client/streamable_client.py
Normal file
@@ -0,0 +1,485 @@
|
||||
"""
|
||||
StreamableHTTP Client Transport Module
|
||||
|
||||
This module implements the StreamableHTTP transport for MCP clients,
|
||||
providing support for HTTP POST requests with optional SSE streaming responses
|
||||
and session management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Callable, Generator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
from httpx_sse import EventSource, ServerSentEvent
|
||||
|
||||
from core.mcp.types import (
|
||||
ClientMessageMetadata,
|
||||
ErrorData,
|
||||
JSONRPCError,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
RequestId,
|
||||
SessionMessage,
|
||||
)
|
||||
from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SessionMessageOrError = SessionMessage | Exception | None
|
||||
# Queue types with clearer names for their roles
|
||||
ServerToClientQueue = queue.Queue[SessionMessageOrError] # Server to client messages
|
||||
ClientToServerQueue = queue.Queue[SessionMessage | None] # Client to server messages
|
||||
GetSessionIdCallback = Callable[[], str | None]
|
||||
|
||||
MCP_SESSION_ID = "mcp-session-id"
|
||||
LAST_EVENT_ID = "last-event-id"
|
||||
CONTENT_TYPE = "content-type"
|
||||
ACCEPT = "Accept"
|
||||
|
||||
|
||||
JSON = "application/json"
|
||||
SSE = "text/event-stream"
|
||||
|
||||
DEFAULT_QUEUE_READ_TIMEOUT = 3
|
||||
|
||||
|
||||
class StreamableHTTPError(Exception):
|
||||
"""Base exception for StreamableHTTP transport errors."""
|
||||
|
||||
|
||||
class ResumptionError(StreamableHTTPError):
|
||||
"""Raised when resumption request is invalid."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext:
|
||||
"""Context for a request operation."""
|
||||
|
||||
client: httpx.Client
|
||||
headers: dict[str, str]
|
||||
session_id: str | None
|
||||
session_message: SessionMessage
|
||||
metadata: ClientMessageMetadata | None
|
||||
server_to_client_queue: ServerToClientQueue # Renamed for clarity
|
||||
sse_read_timeout: float
|
||||
|
||||
|
||||
class StreamableHTTPTransport:
|
||||
"""StreamableHTTP client transport implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float | timedelta = 30,
|
||||
sse_read_timeout: float | timedelta = 60 * 5,
|
||||
):
|
||||
"""Initialize the StreamableHTTP transport.
|
||||
|
||||
Args:
|
||||
url: The endpoint URL.
|
||||
headers: Optional headers to include in requests.
|
||||
timeout: HTTP timeout for regular operations.
|
||||
sse_read_timeout: Timeout for SSE read operations.
|
||||
"""
|
||||
self.url = url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
|
||||
self.sse_read_timeout = (
|
||||
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
|
||||
)
|
||||
self.session_id: str | None = None
|
||||
self.request_headers = {
|
||||
ACCEPT: f"{JSON}, {SSE}",
|
||||
CONTENT_TYPE: JSON,
|
||||
**self.headers,
|
||||
}
|
||||
|
||||
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
|
||||
"""Update headers with session ID if available."""
|
||||
headers = base_headers.copy()
|
||||
if self.session_id:
|
||||
headers[MCP_SESSION_ID] = self.session_id
|
||||
return headers
|
||||
|
||||
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
|
||||
"""Check if the message is an initialization request."""
|
||||
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
|
||||
|
||||
def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
|
||||
"""Check if the message is an initialized notification."""
|
||||
return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized"
|
||||
|
||||
def _maybe_extract_session_id_from_response(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
):
|
||||
"""Extract and store session ID from response headers."""
|
||||
new_session_id = response.headers.get(MCP_SESSION_ID)
|
||||
if new_session_id:
|
||||
self.session_id = new_session_id
|
||||
logger.info("Received session ID: %s", self.session_id)
|
||||
|
||||
def _handle_sse_event(
|
||||
self,
|
||||
sse: ServerSentEvent,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
original_request_id: RequestId | None = None,
|
||||
resumption_callback: Callable[[str], None] | None = None,
|
||||
) -> bool:
|
||||
"""Handle an SSE event, returning True if the response is complete."""
|
||||
if sse.event == "message":
|
||||
# ping event send by server will be recognized as a message event with empty data by httpx-sse's SSEDecoder
|
||||
if not sse.data.strip():
|
||||
return False
|
||||
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate_json(sse.data)
|
||||
logger.debug("SSE message: %s", message)
|
||||
|
||||
# If this is a response and we have original_request_id, replace it
|
||||
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
|
||||
message.root.id = original_request_id
|
||||
|
||||
session_message = SessionMessage(message)
|
||||
# Put message in queue that goes to client
|
||||
server_to_client_queue.put(session_message)
|
||||
|
||||
# Call resumption token callback if we have an ID
|
||||
if sse.id and resumption_callback:
|
||||
resumption_callback(sse.id)
|
||||
|
||||
# If this is a response or error return True indicating completion
|
||||
# Otherwise, return False to continue listening
|
||||
return isinstance(message.root, JSONRPCResponse | JSONRPCError)
|
||||
|
||||
except Exception as exc:
|
||||
# Put exception in queue that goes to client
|
||||
server_to_client_queue.put(exc)
|
||||
return False
|
||||
elif sse.event == "ping":
|
||||
logger.debug("Received ping event")
|
||||
return False
|
||||
else:
|
||||
logger.warning("Unknown SSE event: %s", sse.event)
|
||||
return False
|
||||
|
||||
def handle_get_stream(
|
||||
self,
|
||||
client: httpx.Client,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
):
|
||||
"""Handle GET stream for server-initiated messages."""
|
||||
try:
|
||||
if not self.session_id:
|
||||
return
|
||||
|
||||
headers = self._update_headers_with_session(self.request_headers)
|
||||
|
||||
with ssrf_proxy_sse_connect(
|
||||
self.url,
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
|
||||
client=client,
|
||||
method="GET",
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("GET SSE connection established")
|
||||
|
||||
for sse in event_source.iter_sse():
|
||||
self._handle_sse_event(sse, server_to_client_queue)
|
||||
|
||||
except Exception as exc:
|
||||
logger.debug("GET stream error (non-fatal): %s", exc)
|
||||
|
||||
def _handle_resumption_request(self, ctx: RequestContext):
|
||||
"""Handle a resumption request using GET with SSE."""
|
||||
headers = self._update_headers_with_session(ctx.headers)
|
||||
if ctx.metadata and ctx.metadata.resumption_token:
|
||||
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
|
||||
else:
|
||||
raise ResumptionError("Resumption request requires a resumption token")
|
||||
|
||||
# Extract original request ID to map responses
|
||||
original_request_id = None
|
||||
if isinstance(ctx.session_message.message.root, JSONRPCRequest):
|
||||
original_request_id = ctx.session_message.message.root.id
|
||||
|
||||
with ssrf_proxy_sse_connect(
|
||||
self.url,
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
|
||||
client=ctx.client,
|
||||
method="GET",
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("Resumption GET SSE connection established")
|
||||
|
||||
for sse in event_source.iter_sse():
|
||||
is_complete = self._handle_sse_event(
|
||||
sse,
|
||||
ctx.server_to_client_queue,
|
||||
original_request_id,
|
||||
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
|
||||
)
|
||||
if is_complete:
|
||||
break
|
||||
|
||||
def _handle_post_request(self, ctx: RequestContext):
|
||||
"""Handle a POST request with response processing."""
|
||||
headers = self._update_headers_with_session(ctx.headers)
|
||||
message = ctx.session_message.message
|
||||
is_initialization = self._is_initialization_request(message)
|
||||
|
||||
with ctx.client.stream(
|
||||
"POST",
|
||||
self.url,
|
||||
json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
headers=headers,
|
||||
) as response:
|
||||
if response.status_code == 202:
|
||||
logger.debug("Received 202 Accepted")
|
||||
return
|
||||
|
||||
if response.status_code == 204:
|
||||
logger.debug("Received 204 No Content")
|
||||
return
|
||||
|
||||
if response.status_code == 404:
|
||||
if isinstance(message.root, JSONRPCRequest):
|
||||
self._send_session_terminated_error(
|
||||
ctx.server_to_client_queue,
|
||||
message.root.id,
|
||||
)
|
||||
return
|
||||
|
||||
response.raise_for_status()
|
||||
if is_initialization:
|
||||
self._maybe_extract_session_id_from_response(response)
|
||||
|
||||
content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
|
||||
|
||||
if content_type.startswith(JSON):
|
||||
self._handle_json_response(response, ctx.server_to_client_queue)
|
||||
elif content_type.startswith(SSE):
|
||||
self._handle_sse_response(response, ctx)
|
||||
else:
|
||||
self._handle_unexpected_content_type(
|
||||
content_type,
|
||||
ctx.server_to_client_queue,
|
||||
)
|
||||
|
||||
def _handle_json_response(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
):
|
||||
"""Handle JSON response from the server."""
|
||||
try:
|
||||
content = response.read()
|
||||
message = JSONRPCMessage.model_validate_json(content)
|
||||
session_message = SessionMessage(message)
|
||||
server_to_client_queue.put(session_message)
|
||||
except Exception as exc:
|
||||
server_to_client_queue.put(exc)
|
||||
|
||||
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext):
|
||||
"""Handle SSE response from the server."""
|
||||
try:
|
||||
event_source = EventSource(response)
|
||||
for sse in event_source.iter_sse():
|
||||
is_complete = self._handle_sse_event(
|
||||
sse,
|
||||
ctx.server_to_client_queue,
|
||||
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
|
||||
)
|
||||
if is_complete:
|
||||
break
|
||||
except Exception as e:
|
||||
ctx.server_to_client_queue.put(e)
|
||||
|
||||
def _handle_unexpected_content_type(
|
||||
self,
|
||||
content_type: str,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
):
|
||||
"""Handle unexpected content type in response."""
|
||||
error_msg = f"Unexpected content type: {content_type}"
|
||||
logger.error(error_msg)
|
||||
server_to_client_queue.put(ValueError(error_msg))
|
||||
|
||||
def _send_session_terminated_error(
|
||||
self,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
request_id: RequestId,
|
||||
):
|
||||
"""Send a session terminated error response."""
|
||||
jsonrpc_error = JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
error=ErrorData(code=32600, message="Session terminated by server"),
|
||||
)
|
||||
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
|
||||
server_to_client_queue.put(session_message)
|
||||
|
||||
def post_writer(
|
||||
self,
|
||||
client: httpx.Client,
|
||||
client_to_server_queue: ClientToServerQueue,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
start_get_stream: Callable[[], None],
|
||||
):
|
||||
"""Handle writing requests to the server.
|
||||
|
||||
This method processes messages from the client_to_server_queue and sends them to the server.
|
||||
Responses are written to the server_to_client_queue.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
# Read message from client queue with timeout to check stop_event periodically
|
||||
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
||||
if session_message is None:
|
||||
break
|
||||
|
||||
message = session_message.message
|
||||
metadata = (
|
||||
session_message.metadata if isinstance(session_message.metadata, ClientMessageMetadata) else None
|
||||
)
|
||||
|
||||
# Check if this is a resumption request
|
||||
is_resumption = bool(metadata and metadata.resumption_token)
|
||||
|
||||
logger.debug("Sending client message: %s", message)
|
||||
|
||||
# Handle initialized notification
|
||||
if self._is_initialized_notification(message):
|
||||
start_get_stream()
|
||||
|
||||
ctx = RequestContext(
|
||||
client=client,
|
||||
headers=self.request_headers,
|
||||
session_id=self.session_id,
|
||||
session_message=session_message,
|
||||
metadata=metadata,
|
||||
server_to_client_queue=server_to_client_queue, # Queue to write responses to client
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
|
||||
if is_resumption:
|
||||
self._handle_resumption_request(ctx)
|
||||
else:
|
||||
self._handle_post_request(ctx)
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as exc:
|
||||
server_to_client_queue.put(exc)
|
||||
|
||||
def terminate_session(self, client: httpx.Client):
|
||||
"""Terminate the session by sending a DELETE request."""
|
||||
if not self.session_id:
|
||||
return
|
||||
|
||||
try:
|
||||
headers = self._update_headers_with_session(self.request_headers)
|
||||
response = client.delete(self.url, headers=headers)
|
||||
|
||||
if response.status_code == 405:
|
||||
logger.debug("Server does not allow session termination")
|
||||
elif response.status_code != 200:
|
||||
logger.warning("Session termination failed: %s", response.status_code)
|
||||
except Exception as exc:
|
||||
logger.warning("Session termination failed: %s", exc)
|
||||
|
||||
def get_session_id(self) -> str | None:
|
||||
"""Get the current session ID."""
|
||||
return self.session_id
|
||||
|
||||
|
||||
@contextmanager
|
||||
def streamablehttp_client(
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float | timedelta = 30,
|
||||
sse_read_timeout: float | timedelta = 60 * 5,
|
||||
terminate_on_close: bool = True,
|
||||
) -> Generator[
|
||||
tuple[
|
||||
ServerToClientQueue, # Queue for receiving messages FROM server
|
||||
ClientToServerQueue, # Queue for sending messages TO server
|
||||
GetSessionIdCallback,
|
||||
],
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
"""
|
||||
Client transport for StreamableHTTP.
|
||||
|
||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
|
||||
Yields:
|
||||
Tuple containing:
|
||||
- server_to_client_queue: Queue for reading messages FROM the server
|
||||
- client_to_server_queue: Queue for sending messages TO the server
|
||||
- get_session_id_callback: Function to retrieve the current session ID
|
||||
"""
|
||||
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout)
|
||||
|
||||
# Create queues with clear directional meaning
|
||||
server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client
|
||||
client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=2)
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(
|
||||
headers=transport.request_headers,
|
||||
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
|
||||
) as client:
|
||||
# Define callbacks that need access to thread pool
|
||||
def start_get_stream():
|
||||
"""Start a worker thread to handle server-initiated messages."""
|
||||
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
|
||||
|
||||
# Start the post_writer worker thread
|
||||
executor.submit(
|
||||
transport.post_writer,
|
||||
client,
|
||||
client_to_server_queue, # Queue for messages FROM client TO server
|
||||
server_to_client_queue, # Queue for messages FROM server TO client
|
||||
start_get_stream,
|
||||
)
|
||||
|
||||
try:
|
||||
yield (
|
||||
server_to_client_queue, # Queue for receiving messages FROM server
|
||||
client_to_server_queue, # Queue for sending messages TO server
|
||||
transport.get_session_id,
|
||||
)
|
||||
finally:
|
||||
if transport.session_id and terminate_on_close:
|
||||
transport.terminate_session(client)
|
||||
|
||||
# Signal threads to stop
|
||||
client_to_server_queue.put(None)
|
||||
finally:
|
||||
# Clear any remaining items and add None sentinel to unblock any waiting threads
|
||||
try:
|
||||
while not client_to_server_queue.empty():
|
||||
client_to_server_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
client_to_server_queue.put(None)
|
||||
server_to_client_queue.put(None)
|
||||
|
||||
# Shutdown executor without waiting to prevent hanging
|
||||
executor.shutdown(wait=False)
|
||||
60
dify/api/core/mcp/entities.py
Normal file
60
dify/api/core/mcp/entities.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.mcp.session.base_session import BaseSession
|
||||
from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAuthMetadata, RequestId, RequestParams
|
||||
|
||||
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
|
||||
|
||||
|
||||
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
|
||||
LifespanContextT = TypeVar("LifespanContextT")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext(Generic[SessionT, LifespanContextT]):
|
||||
request_id: RequestId
|
||||
meta: RequestParams.Meta | None
|
||||
session: SessionT
|
||||
lifespan_context: LifespanContextT
|
||||
|
||||
|
||||
class AuthActionType(StrEnum):
|
||||
"""Types of actions that can be performed during auth flow."""
|
||||
|
||||
SAVE_CLIENT_INFO = "save_client_info"
|
||||
SAVE_TOKENS = "save_tokens"
|
||||
SAVE_CODE_VERIFIER = "save_code_verifier"
|
||||
START_AUTHORIZATION = "start_authorization"
|
||||
SUCCESS = "success"
|
||||
|
||||
|
||||
class AuthAction(BaseModel):
|
||||
"""Represents an action that needs to be performed as a result of auth flow."""
|
||||
|
||||
action_type: AuthActionType
|
||||
data: dict[str, Any]
|
||||
provider_id: str | None = None
|
||||
tenant_id: str | None = None
|
||||
|
||||
|
||||
class AuthResult(BaseModel):
|
||||
"""Result of auth function containing actions to be performed and response data."""
|
||||
|
||||
actions: list[AuthAction]
|
||||
response: dict[str, str]
|
||||
|
||||
|
||||
class OAuthCallbackState(BaseModel):
|
||||
"""State data stored in Redis during OAuth callback flow."""
|
||||
|
||||
provider_id: str
|
||||
tenant_id: str
|
||||
server_url: str
|
||||
metadata: OAuthMetadata | None = None
|
||||
client_information: OAuthClientInformation
|
||||
code_verifier: str
|
||||
redirect_uri: str
|
||||
63
dify/api/core/mcp/error.py
Normal file
63
dify/api/core/mcp/error.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import httpx
|
||||
|
||||
|
||||
class MCPError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class MCPConnectionError(MCPError):
|
||||
pass
|
||||
|
||||
|
||||
class MCPAuthError(MCPConnectionError):
|
||||
def __init__(
|
||||
self,
|
||||
message: str | None = None,
|
||||
response: "httpx.Response | None" = None,
|
||||
www_authenticate_header: str | None = None,
|
||||
):
|
||||
"""
|
||||
MCP Authentication Error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
response: HTTP response object (will extract WWW-Authenticate header if provided)
|
||||
www_authenticate_header: Pre-extracted WWW-Authenticate header value
|
||||
"""
|
||||
super().__init__(message or "Authentication failed")
|
||||
|
||||
# Extract OAuth metadata hints from WWW-Authenticate header
|
||||
if response is not None:
|
||||
www_authenticate_header = response.headers.get("WWW-Authenticate")
|
||||
|
||||
self.resource_metadata_url: str | None = None
|
||||
self.scope_hint: str | None = None
|
||||
|
||||
if www_authenticate_header:
|
||||
self.resource_metadata_url = self._extract_field(www_authenticate_header, "resource_metadata")
|
||||
self.scope_hint = self._extract_field(www_authenticate_header, "scope")
|
||||
|
||||
@staticmethod
|
||||
def _extract_field(www_auth: str, field_name: str) -> str | None:
|
||||
"""Extract a specific field from the WWW-Authenticate header."""
|
||||
# Pattern to match field="value" or field=value
|
||||
pattern = rf'{field_name}="([^"]*)"'
|
||||
match = re.search(pattern, www_auth)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
# Try without quotes
|
||||
pattern = rf"{field_name}=([^\s,]+)"
|
||||
match = re.search(pattern, www_auth)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class MCPRefreshTokenError(MCPError):
|
||||
pass
|
||||
115
dify/api/core/mcp/mcp_client.py
Normal file
115
dify/api/core/mcp/mcp_client.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from contextlib import AbstractContextManager, ExitStack
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from core.mcp.client.sse_client import sse_client
|
||||
from core.mcp.client.streamable_client import streamablehttp_client
|
||||
from core.mcp.error import MCPConnectionError
|
||||
from core.mcp.session.client_session import ClientSession
|
||||
from core.mcp.types import CallToolResult, Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPClient:
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
):
|
||||
self.server_url = server_url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
|
||||
# Initialize session and client objects
|
||||
self._session: ClientSession | None = None
|
||||
self._exit_stack = ExitStack()
|
||||
self._initialized = False
|
||||
|
||||
def __enter__(self):
|
||||
self._initialize()
|
||||
self._initialized = True
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: type | None, exc_value: BaseException | None, traceback: TracebackType | None):
|
||||
self.cleanup()
|
||||
|
||||
def _initialize(
|
||||
self,
|
||||
):
|
||||
"""Initialize the client with fallback to SSE if streamable connection fails"""
|
||||
connection_methods: dict[str, Callable[..., AbstractContextManager[Any]]] = {
|
||||
"mcp": streamablehttp_client,
|
||||
"sse": sse_client,
|
||||
}
|
||||
|
||||
parsed_url = urlparse(self.server_url)
|
||||
path = parsed_url.path or ""
|
||||
method_name = path.rstrip("/").split("/")[-1] if path else ""
|
||||
if method_name in connection_methods:
|
||||
client_factory = connection_methods[method_name]
|
||||
self.connect_server(client_factory, method_name)
|
||||
else:
|
||||
try:
|
||||
logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name)
|
||||
self.connect_server(sse_client, "sse")
|
||||
except MCPConnectionError:
|
||||
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
|
||||
self.connect_server(streamablehttp_client, "mcp")
|
||||
|
||||
def connect_server(self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str) -> None:
|
||||
"""
|
||||
Connect to the MCP server using streamable http or sse.
|
||||
Default to streamable http.
|
||||
Args:
|
||||
client_factory: The client factory to use(streamablehttp_client or sse_client).
|
||||
method_name: The method name to use(mcp or sse).
|
||||
"""
|
||||
streams_context = client_factory(
|
||||
url=self.server_url,
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
|
||||
# Use exit_stack to manage context managers properly
|
||||
if method_name == "mcp":
|
||||
read_stream, write_stream, _ = self._exit_stack.enter_context(streams_context)
|
||||
streams = (read_stream, write_stream)
|
||||
else: # sse_client
|
||||
streams = self._exit_stack.enter_context(streams_context)
|
||||
|
||||
session_context = ClientSession(*streams)
|
||||
self._session = self._exit_stack.enter_context(session_context)
|
||||
self._session.initialize()
|
||||
|
||||
def list_tools(self) -> list[Tool]:
|
||||
"""List available tools from the MCP server"""
|
||||
if not self._session:
|
||||
raise ValueError("Session not initialized.")
|
||||
response = self._session.list_tools()
|
||||
return response.tools
|
||||
|
||||
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
|
||||
"""Call a tool"""
|
||||
if not self._session:
|
||||
raise ValueError("Session not initialized.")
|
||||
return self._session.call_tool(tool_name, tool_args)
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up resources"""
|
||||
try:
|
||||
# ExitStack will handle proper cleanup of all managed context managers
|
||||
self._exit_stack.close()
|
||||
except Exception as e:
|
||||
logger.exception("Error during cleanup")
|
||||
raise ValueError(f"Error during cleanup: {e}")
|
||||
finally:
|
||||
self._session = None
|
||||
self._initialized = False
|
||||
262
dify/api/core/mcp/server/streamable_http.py
Normal file
262
dify/api/core/mcp/server/streamable_http.py
Normal file
@@ -0,0 +1,262 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
|
||||
from core.mcp import types as mcp_types
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def handle_mcp_request(
|
||||
app: App,
|
||||
request: mcp_types.ClientRequest,
|
||||
user_input_form: list[VariableEntity],
|
||||
mcp_server: AppMCPServer,
|
||||
end_user: EndUser | None = None,
|
||||
request_id: int | str = 1,
|
||||
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError:
|
||||
"""
|
||||
Handle MCP request and return JSON-RPC response
|
||||
|
||||
Args:
|
||||
app: The Dify app instance
|
||||
request: The JSON-RPC request message
|
||||
user_input_form: List of variable entities for the app
|
||||
mcp_server: The MCP server configuration
|
||||
end_user: Optional end user
|
||||
request_id: The request ID
|
||||
|
||||
Returns:
|
||||
JSON-RPC response or error
|
||||
"""
|
||||
|
||||
request_type = type(request.root)
|
||||
request_root = request.root
|
||||
|
||||
def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse:
|
||||
"""Create success response with business result data"""
|
||||
return mcp_types.JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
result=result_data.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
def create_error_response(code: int, message: str) -> mcp_types.JSONRPCError:
|
||||
"""Create error response with error code and message"""
|
||||
from core.mcp.types import ErrorData
|
||||
|
||||
error_data = ErrorData(code=code, message=message)
|
||||
return mcp_types.JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
error=error_data,
|
||||
)
|
||||
|
||||
try:
|
||||
# Dispatch request to appropriate handler based on instance type
|
||||
if isinstance(request_root, mcp_types.InitializeRequest):
|
||||
return create_success_response(handle_initialize(mcp_server.description))
|
||||
elif isinstance(request_root, mcp_types.ListToolsRequest):
|
||||
return create_success_response(
|
||||
handle_list_tools(
|
||||
app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict
|
||||
)
|
||||
)
|
||||
elif isinstance(request_root, mcp_types.CallToolRequest):
|
||||
return create_success_response(handle_call_tool(app, request, user_input_form, end_user))
|
||||
elif isinstance(request_root, mcp_types.PingRequest):
|
||||
return create_success_response(handle_ping())
|
||||
else:
|
||||
return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}")
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Invalid params")
|
||||
return create_error_response(mcp_types.INVALID_PARAMS, str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Internal server error")
|
||||
return create_error_response(mcp_types.INTERNAL_ERROR, "Internal server error: " + str(e))
|
||||
|
||||
|
||||
def handle_ping() -> mcp_types.EmptyResult:
|
||||
"""Handle ping request"""
|
||||
return mcp_types.EmptyResult()
|
||||
|
||||
|
||||
def handle_initialize(description: str) -> mcp_types.InitializeResult:
|
||||
"""Handle initialize request"""
|
||||
capabilities = mcp_types.ServerCapabilities(
|
||||
tools=mcp_types.ToolsCapability(listChanged=False),
|
||||
)
|
||||
|
||||
return mcp_types.InitializeResult(
|
||||
protocolVersion=mcp_types.SERVER_LATEST_PROTOCOL_VERSION,
|
||||
capabilities=capabilities,
|
||||
serverInfo=mcp_types.Implementation(name="Dify", version=dify_config.project.version),
|
||||
instructions=description,
|
||||
)
|
||||
|
||||
|
||||
def handle_list_tools(
|
||||
app_name: str,
|
||||
app_mode: str,
|
||||
user_input_form: list[VariableEntity],
|
||||
description: str,
|
||||
parameters_dict: dict[str, str],
|
||||
) -> mcp_types.ListToolsResult:
|
||||
"""Handle list tools request"""
|
||||
parameter_schema = build_parameter_schema(app_mode, user_input_form, parameters_dict)
|
||||
|
||||
return mcp_types.ListToolsResult(
|
||||
tools=[
|
||||
mcp_types.Tool(
|
||||
name=app_name,
|
||||
description=description,
|
||||
inputSchema=parameter_schema,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def handle_call_tool(
|
||||
app: App,
|
||||
request: mcp_types.ClientRequest,
|
||||
user_input_form: list[VariableEntity],
|
||||
end_user: EndUser | None,
|
||||
) -> mcp_types.CallToolResult:
|
||||
"""Handle call tool request"""
|
||||
request_obj = cast(mcp_types.CallToolRequest, request.root)
|
||||
args = prepare_tool_arguments(app, request_obj.params.arguments or {})
|
||||
|
||||
if not end_user:
|
||||
raise ValueError("End user not found")
|
||||
|
||||
response = AppGenerateService.generate(
|
||||
app,
|
||||
end_user,
|
||||
args,
|
||||
InvokeFrom.SERVICE_API,
|
||||
streaming=app.mode == AppMode.AGENT_CHAT,
|
||||
)
|
||||
|
||||
answer = extract_answer_from_response(app, response)
|
||||
return mcp_types.CallToolResult(content=[mcp_types.TextContent(text=answer, type="text")])
|
||||
|
||||
|
||||
def build_parameter_schema(
|
||||
app_mode: str,
|
||||
user_input_form: list[VariableEntity],
|
||||
parameters_dict: dict[str, str],
|
||||
) -> dict[str, Any]:
|
||||
"""Build parameter schema for the tool"""
|
||||
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
|
||||
|
||||
if app_mode in {AppMode.COMPLETION, AppMode.WORKFLOW}:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": parameters,
|
||||
"required": required,
|
||||
}
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "User Input/Question content"},
|
||||
**parameters,
|
||||
},
|
||||
"required": ["query", *required],
|
||||
}
|
||||
|
||||
|
||||
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Prepare arguments based on app mode"""
|
||||
if app.mode == AppMode.WORKFLOW:
|
||||
return {"inputs": arguments}
|
||||
elif app.mode == AppMode.COMPLETION:
|
||||
return {"query": "", "inputs": arguments}
|
||||
else:
|
||||
# Chat modes - create a copy to avoid modifying original dict
|
||||
args_copy = arguments.copy()
|
||||
query = args_copy.pop("query", "")
|
||||
return {"query": query, "inputs": args_copy}
|
||||
|
||||
|
||||
def extract_answer_from_response(app: App, response: Any) -> str:
|
||||
"""Extract answer from app generate response"""
|
||||
answer = ""
|
||||
|
||||
if isinstance(response, RateLimitGenerator):
|
||||
answer = process_streaming_response(response)
|
||||
elif isinstance(response, Mapping):
|
||||
answer = process_mapping_response(app, response)
|
||||
else:
|
||||
logger.warning("Unexpected response type: %s", type(response))
|
||||
|
||||
return answer
|
||||
|
||||
|
||||
def process_streaming_response(response: RateLimitGenerator) -> str:
|
||||
"""Process streaming response for agent chat mode"""
|
||||
answer = ""
|
||||
for item in response.generator:
|
||||
if isinstance(item, str) and item.startswith("data: "):
|
||||
try:
|
||||
json_str = item[6:].strip()
|
||||
parsed_data = json.loads(json_str)
|
||||
if parsed_data.get("event") == "agent_thought":
|
||||
answer += parsed_data.get("thought", "")
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
return answer
|
||||
|
||||
|
||||
def process_mapping_response(app: App, response: Mapping) -> str:
|
||||
"""Process mapping response based on app mode"""
|
||||
if app.mode in {
|
||||
AppMode.ADVANCED_CHAT,
|
||||
AppMode.COMPLETION,
|
||||
AppMode.CHAT,
|
||||
AppMode.AGENT_CHAT,
|
||||
}:
|
||||
return response.get("answer", "")
|
||||
elif app.mode == AppMode.WORKFLOW:
|
||||
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
|
||||
else:
|
||||
raise ValueError("Invalid app mode: " + str(app.mode))
|
||||
|
||||
|
||||
def convert_input_form_to_parameters(
|
||||
user_input_form: list[VariableEntity],
|
||||
parameters_dict: dict[str, str],
|
||||
) -> tuple[dict[str, dict[str, Any]], list[str]]:
|
||||
"""Convert user input form to parameter schema"""
|
||||
parameters: dict[str, dict[str, Any]] = {}
|
||||
required = []
|
||||
|
||||
for item in user_input_form:
|
||||
if item.type in (
|
||||
VariableEntityType.FILE,
|
||||
VariableEntityType.FILE_LIST,
|
||||
VariableEntityType.EXTERNAL_DATA_TOOL,
|
||||
):
|
||||
continue
|
||||
parameters[item.variable] = {}
|
||||
if item.required:
|
||||
required.append(item.variable)
|
||||
# if the workflow republished, the parameters not changed
|
||||
# we should not raise error here
|
||||
description = parameters_dict.get(item.variable, "")
|
||||
parameters[item.variable]["description"] = description
|
||||
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
|
||||
parameters[item.variable]["type"] = "string"
|
||||
elif item.type == VariableEntityType.SELECT:
|
||||
parameters[item.variable]["type"] = "string"
|
||||
parameters[item.variable]["enum"] = item.options
|
||||
elif item.type == VariableEntityType.NUMBER:
|
||||
parameters[item.variable]["type"] = "number"
|
||||
return parameters, required
|
||||
423
dify/api/core/mcp/session/base_session.py
Normal file
423
dify/api/core/mcp/session/base_session.py
Normal file
@@ -0,0 +1,423 @@
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
|
||||
from datetime import timedelta
|
||||
from types import TracebackType
|
||||
from typing import Any, Generic, Self, TypeVar
|
||||
|
||||
from httpx import HTTPStatusError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.types import (
|
||||
CancelledNotification,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
ClientResult,
|
||||
ErrorData,
|
||||
JSONRPCError,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
MessageMetadata,
|
||||
RequestId,
|
||||
RequestParams,
|
||||
ServerMessageMetadata,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
ServerResult,
|
||||
SessionMessage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
|
||||
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
|
||||
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
|
||||
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
|
||||
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
|
||||
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
|
||||
DEFAULT_RESPONSE_READ_TIMEOUT = 1.0
|
||||
|
||||
|
||||
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
"""Handles responding to MCP requests and manages request lifecycle.
|
||||
|
||||
This class MUST be used as a context manager to ensure proper cleanup and
|
||||
cancellation handling:
|
||||
|
||||
Example:
|
||||
with request_responder as resp:
|
||||
resp.respond(result)
|
||||
|
||||
The context manager ensures:
|
||||
1. Proper cancellation scope setup and cleanup
|
||||
2. Request completion tracking
|
||||
3. Cleanup of in-flight requests
|
||||
"""
|
||||
|
||||
request: ReceiveRequestT
|
||||
_session: Any
|
||||
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
request_id: RequestId,
|
||||
request_meta: RequestParams.Meta | None,
|
||||
request: ReceiveRequestT,
|
||||
session: """BaseSession[
|
||||
SendRequestT,
|
||||
SendNotificationT,
|
||||
SendResultT,
|
||||
ReceiveRequestT,
|
||||
ReceiveNotificationT
|
||||
]""",
|
||||
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
|
||||
):
|
||||
self.request_id = request_id
|
||||
self.request_meta = request_meta
|
||||
self.request = request
|
||||
self._session = session
|
||||
self.completed = False
|
||||
self._on_complete = on_complete
|
||||
self._entered = False # Track if we're in a context manager
|
||||
|
||||
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
|
||||
"""Enter the context manager, enabling request cancellation tracking."""
|
||||
self._entered = True
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
):
|
||||
"""Exit the context manager, performing cleanup and notifying completion."""
|
||||
try:
|
||||
if self.completed:
|
||||
self._on_complete(self)
|
||||
finally:
|
||||
self._entered = False
|
||||
|
||||
def respond(self, response: SendResultT | ErrorData):
|
||||
"""Send a response for this request.
|
||||
|
||||
Must be called within a context manager block.
|
||||
Raises:
|
||||
RuntimeError: If not used within a context manager
|
||||
AssertionError: If request was already responded to
|
||||
"""
|
||||
if not self._entered:
|
||||
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||
assert not self.completed, "Request already responded to"
|
||||
|
||||
self.completed = True
|
||||
|
||||
self._session._send_response(request_id=self.request_id, response=response)
|
||||
|
||||
def cancel(self):
|
||||
"""Cancel this request and mark it as completed."""
|
||||
if not self._entered:
|
||||
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||
|
||||
self.completed = True # Mark as completed so it's removed from in_flight
|
||||
# Send an error response to indicate cancellation
|
||||
self._session._send_response(
|
||||
request_id=self.request_id,
|
||||
response=ErrorData(code=0, message="Request cancelled", data=None),
|
||||
)
|
||||
|
||||
|
||||
class BaseSession(
|
||||
Generic[
|
||||
SendRequestT,
|
||||
SendNotificationT,
|
||||
SendResultT,
|
||||
ReceiveRequestT,
|
||||
ReceiveNotificationT,
|
||||
],
|
||||
):
|
||||
"""
|
||||
Implements an MCP "session" on top of read/write streams, including features
|
||||
like request/response linking, notifications, and progress.
|
||||
|
||||
This class is a context manager that automatically starts processing
|
||||
messages when entered.
|
||||
"""
|
||||
|
||||
_response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError | HTTPStatusError]]
|
||||
_request_id: int
|
||||
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
|
||||
_receive_request_type: type[ReceiveRequestT]
|
||||
_receive_notification_type: type[ReceiveNotificationT]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: queue.Queue,
|
||||
write_stream: queue.Queue,
|
||||
receive_request_type: type[ReceiveRequestT],
|
||||
receive_notification_type: type[ReceiveNotificationT],
|
||||
# If none, reading will never time out
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
):
|
||||
self._read_stream = read_stream
|
||||
self._write_stream = write_stream
|
||||
self._response_streams = {}
|
||||
self._request_id = 0
|
||||
self._receive_request_type = receive_request_type
|
||||
self._receive_notification_type = receive_notification_type
|
||||
self._session_read_timeout_seconds = read_timeout_seconds
|
||||
self._in_flight = {}
|
||||
# Initialize executor and future to None for proper cleanup checks
|
||||
self._executor: ThreadPoolExecutor | None = None
|
||||
self._receiver_future: Future | None = None
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
# The thread pool is dedicated to running `_receive_loop`. Setting `max_workers` to 1
|
||||
# ensures no unnecessary threads are created.
|
||||
self._executor = ThreadPoolExecutor(max_workers=1)
|
||||
self._receiver_future = self._executor.submit(self._receive_loop)
|
||||
return self
|
||||
|
||||
def check_receiver_status(self):
|
||||
"""`check_receiver_status` ensures that any exceptions raised during the
|
||||
execution of `_receive_loop` are retrieved and propagated."""
|
||||
if self._receiver_future and self._receiver_future.done():
|
||||
self._receiver_future.result()
|
||||
|
||||
def __exit__(
|
||||
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
|
||||
):
|
||||
self._read_stream.put(None)
|
||||
self._write_stream.put(None)
|
||||
|
||||
# Wait for the receiver loop to finish
|
||||
if self._receiver_future:
|
||||
try:
|
||||
self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds
|
||||
except TimeoutError:
|
||||
# If the receiver loop is still running after timeout, we'll force shutdown
|
||||
# Cancel the future to interrupt the receiver loop
|
||||
self._receiver_future.cancel()
|
||||
|
||||
# Shutdown the executor
|
||||
if self._executor:
|
||||
# Use non-blocking shutdown to prevent hanging
|
||||
# The receiver thread should have already exited due to the None message in the queue
|
||||
self._executor.shutdown(wait=False)
|
||||
|
||||
def send_request(
|
||||
self,
|
||||
request: SendRequestT,
|
||||
result_type: type[ReceiveResultT],
|
||||
request_read_timeout_seconds: timedelta | None = None,
|
||||
metadata: MessageMetadata | None = None,
|
||||
) -> ReceiveResultT:
|
||||
"""
|
||||
Sends a request and wait for a response. Raises an McpError if the
|
||||
response contains an error. If a request read timeout is provided, it
|
||||
will take precedence over the session read timeout.
|
||||
|
||||
Do not use this method to emit notifications! Use send_notification()
|
||||
instead.
|
||||
"""
|
||||
self.check_receiver_status()
|
||||
|
||||
request_id = self._request_id
|
||||
self._request_id = request_id + 1
|
||||
|
||||
response_queue: queue.Queue[JSONRPCResponse | JSONRPCError | HTTPStatusError] = queue.Queue()
|
||||
self._response_streams[request_id] = response_queue
|
||||
|
||||
try:
|
||||
jsonrpc_request = JSONRPCRequest(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
self._write_stream.put(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
|
||||
timeout = DEFAULT_RESPONSE_READ_TIMEOUT
|
||||
if request_read_timeout_seconds is not None:
|
||||
timeout = float(request_read_timeout_seconds.total_seconds())
|
||||
elif self._session_read_timeout_seconds is not None:
|
||||
timeout = float(self._session_read_timeout_seconds.total_seconds())
|
||||
while True:
|
||||
try:
|
||||
response_or_error = response_queue.get(timeout=timeout)
|
||||
break
|
||||
except queue.Empty:
|
||||
self.check_receiver_status()
|
||||
continue
|
||||
|
||||
if response_or_error is None:
|
||||
raise MCPConnectionError(
|
||||
ErrorData(
|
||||
code=500,
|
||||
message="No response received",
|
||||
)
|
||||
)
|
||||
elif isinstance(response_or_error, HTTPStatusError):
|
||||
# HTTPStatusError from streamable_client with preserved response object
|
||||
if response_or_error.response.status_code == 401:
|
||||
raise MCPAuthError(response=response_or_error.response)
|
||||
else:
|
||||
raise MCPConnectionError(
|
||||
ErrorData(code=response_or_error.response.status_code, message=str(response_or_error))
|
||||
)
|
||||
elif isinstance(response_or_error, JSONRPCError):
|
||||
if response_or_error.error.code == 401:
|
||||
raise MCPAuthError(message=response_or_error.error.message)
|
||||
else:
|
||||
raise MCPConnectionError(
|
||||
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
|
||||
)
|
||||
else:
|
||||
return result_type.model_validate(response_or_error.result)
|
||||
|
||||
finally:
|
||||
self._response_streams.pop(request_id, None)
|
||||
|
||||
def send_notification(
|
||||
self,
|
||||
notification: SendNotificationT,
|
||||
related_request_id: RequestId | None = None,
|
||||
):
|
||||
"""
|
||||
Emits a notification, which is a one-way message that does not expect
|
||||
a response.
|
||||
"""
|
||||
self.check_receiver_status()
|
||||
|
||||
# Some transport implementations may need to set the related_request_id
|
||||
# to attribute to the notifications to the request that triggered them.
|
||||
jsonrpc_notification = JSONRPCNotification(
|
||||
jsonrpc="2.0",
|
||||
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
session_message = SessionMessage(
|
||||
message=JSONRPCMessage(jsonrpc_notification),
|
||||
metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None,
|
||||
)
|
||||
self._write_stream.put(session_message)
|
||||
|
||||
def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData):
|
||||
if isinstance(response, ErrorData):
|
||||
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
|
||||
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
|
||||
self._write_stream.put(session_message)
|
||||
else:
|
||||
jsonrpc_response = JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
|
||||
self._write_stream.put(session_message)
|
||||
|
||||
def _receive_loop(self):
|
||||
"""
|
||||
Main message processing loop.
|
||||
In a real synchronous implementation, this would likely run in a separate thread.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
# Attempt to receive a message (this would be blocking in a synchronous context)
|
||||
message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
|
||||
if message is None:
|
||||
break
|
||||
if isinstance(message, HTTPStatusError):
|
||||
response_queue = self._response_streams.get(self._request_id - 1)
|
||||
if response_queue is not None:
|
||||
# For 401 errors, pass the HTTPStatusError directly to preserve response object
|
||||
if message.response.status_code == 401:
|
||||
response_queue.put(message)
|
||||
else:
|
||||
response_queue.put(
|
||||
JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=self._request_id - 1,
|
||||
error=ErrorData(code=message.response.status_code, message=message.args[0]),
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
|
||||
elif isinstance(message, Exception):
|
||||
self._handle_incoming(message)
|
||||
elif isinstance(message.message.root, JSONRPCRequest):
|
||||
validated_request = self._receive_request_type.model_validate(
|
||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
|
||||
responder = RequestResponder(
|
||||
request_id=message.message.root.id,
|
||||
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
|
||||
request=validated_request,
|
||||
session=self,
|
||||
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
|
||||
)
|
||||
|
||||
self._in_flight[responder.request_id] = responder
|
||||
self._received_request(responder)
|
||||
|
||||
if not responder.completed:
|
||||
self._handle_incoming(responder)
|
||||
|
||||
elif isinstance(message.message.root, JSONRPCNotification):
|
||||
try:
|
||||
notification = self._receive_notification_type.model_validate(
|
||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
# Handle cancellation notifications
|
||||
if isinstance(notification.root, CancelledNotification):
|
||||
cancelled_id = notification.root.params.requestId
|
||||
if cancelled_id in self._in_flight:
|
||||
self._in_flight[cancelled_id].cancel()
|
||||
else:
|
||||
self._received_notification(notification)
|
||||
self._handle_incoming(notification)
|
||||
except Exception as e:
|
||||
# For other validation errors, log and continue
|
||||
logger.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root)
|
||||
else: # Response or error
|
||||
response_queue = self._response_streams.get(message.message.root.id)
|
||||
if response_queue is not None:
|
||||
response_queue.put(message.message.root)
|
||||
else:
|
||||
self._handle_incoming(RuntimeError(f"Server Error: {message}"))
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception:
|
||||
logger.exception("Error in message processing loop")
|
||||
raise
|
||||
|
||||
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]):
|
||||
"""
|
||||
Can be overridden by subclasses to handle a request without needing to
|
||||
listen on the message stream.
|
||||
|
||||
If the request is responded to within this method, it will not be
|
||||
forwarded on to the message stream.
|
||||
"""
|
||||
|
||||
def _received_notification(self, notification: ReceiveNotificationT):
|
||||
"""
|
||||
Can be overridden by subclasses to handle a notification without needing
|
||||
to listen on the message stream.
|
||||
"""
|
||||
|
||||
def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None):
|
||||
"""
|
||||
Sends a progress notification for a request that is currently being
|
||||
processed.
|
||||
"""
|
||||
|
||||
def _handle_incoming(
|
||||
self,
|
||||
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
|
||||
):
|
||||
"""A generic handler for incoming messages. Overwritten by subclasses."""
|
||||
368
dify/api/core/mcp/session/client_session.py
Normal file
368
dify/api/core/mcp/session/client_session.py
Normal file
@@ -0,0 +1,368 @@
|
||||
import queue
|
||||
from datetime import timedelta
|
||||
from typing import Any, Protocol
|
||||
|
||||
from pydantic import AnyUrl, TypeAdapter
|
||||
|
||||
from configs import dify_config
|
||||
from core.mcp import types
|
||||
from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext
|
||||
from core.mcp.session.base_session import BaseSession, RequestResponder
|
||||
|
||||
DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.project.version)
|
||||
|
||||
|
||||
class SamplingFnT(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult | types.ErrorData: ...
|
||||
|
||||
|
||||
class ListRootsFnT(Protocol):
|
||||
def __call__(self, context: RequestContext["ClientSession", Any]) -> types.ListRootsResult | types.ErrorData: ...
|
||||
|
||||
|
||||
class LoggingFnT(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
params: types.LoggingMessageNotificationParams,
|
||||
): ...
|
||||
|
||||
|
||||
class MessageHandlerFnT(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
): ...
|
||||
|
||||
|
||||
def _default_message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
):
|
||||
if isinstance(message, Exception):
|
||||
raise ValueError(str(message))
|
||||
elif isinstance(message, (types.ServerNotification | RequestResponder)):
|
||||
pass
|
||||
|
||||
|
||||
def _default_sampling_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="Sampling not supported",
|
||||
)
|
||||
|
||||
|
||||
def _default_list_roots_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
) -> types.ListRootsResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="List roots not supported",
|
||||
)
|
||||
|
||||
|
||||
def _default_logging_callback(
|
||||
params: types.LoggingMessageNotificationParams,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
|
||||
|
||||
|
||||
class ClientSession(
|
||||
BaseSession[
|
||||
types.ClientRequest,
|
||||
types.ClientNotification,
|
||||
types.ClientResult,
|
||||
types.ServerRequest,
|
||||
types.ServerNotification,
|
||||
]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: queue.Queue,
|
||||
write_stream: queue.Queue,
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
sampling_callback: SamplingFnT | None = None,
|
||||
list_roots_callback: ListRootsFnT | None = None,
|
||||
logging_callback: LoggingFnT | None = None,
|
||||
message_handler: MessageHandlerFnT | None = None,
|
||||
client_info: types.Implementation | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
read_stream,
|
||||
write_stream,
|
||||
types.ServerRequest,
|
||||
types.ServerNotification,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
)
|
||||
self._client_info = client_info or DEFAULT_CLIENT_INFO
|
||||
self._sampling_callback = sampling_callback or _default_sampling_callback
|
||||
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
|
||||
self._logging_callback = logging_callback or _default_logging_callback
|
||||
self._message_handler = message_handler or _default_message_handler
|
||||
|
||||
def initialize(self) -> types.InitializeResult:
|
||||
# Only set capabilities if non-default callbacks are provided
|
||||
# This prevents servers from attempting callbacks when we don't actually support them
|
||||
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
|
||||
roots = (
|
||||
types.RootsCapability(
|
||||
# Only enable listChanged if we have a custom callback
|
||||
listChanged=True,
|
||||
)
|
||||
if self._list_roots_callback is not _default_list_roots_callback
|
||||
else None
|
||||
)
|
||||
|
||||
result = self.send_request(
|
||||
types.ClientRequest(
|
||||
types.InitializeRequest(
|
||||
method="initialize",
|
||||
params=types.InitializeRequestParams(
|
||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||
capabilities=types.ClientCapabilities(
|
||||
sampling=sampling,
|
||||
experimental=None,
|
||||
roots=roots,
|
||||
),
|
||||
clientInfo=self._client_info,
|
||||
),
|
||||
)
|
||||
),
|
||||
types.InitializeResult,
|
||||
)
|
||||
|
||||
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
|
||||
raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}")
|
||||
|
||||
self.send_notification(
|
||||
types.ClientNotification(types.InitializedNotification(method="notifications/initialized"))
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def send_ping(self) -> types.EmptyResult:
|
||||
"""Send a ping request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.PingRequest(
|
||||
method="ping",
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None):
|
||||
"""Send a progress notification."""
|
||||
self.send_notification(
|
||||
types.ClientNotification(
|
||||
types.ProgressNotification(
|
||||
method="notifications/progress",
|
||||
params=types.ProgressNotificationParams(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
|
||||
"""Send a logging/setLevel request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.SetLevelRequest(
|
||||
method="logging/setLevel",
|
||||
params=types.SetLevelRequestParams(level=level),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
def list_resources(self) -> types.ListResourcesResult:
|
||||
"""Send a resources/list request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListResourcesRequest(
|
||||
method="resources/list",
|
||||
)
|
||||
),
|
||||
types.ListResourcesResult,
|
||||
)
|
||||
|
||||
def list_resource_templates(self) -> types.ListResourceTemplatesResult:
|
||||
"""Send a resources/templates/list request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListResourceTemplatesRequest(
|
||||
method="resources/templates/list",
|
||||
)
|
||||
),
|
||||
types.ListResourceTemplatesResult,
|
||||
)
|
||||
|
||||
def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
|
||||
"""Send a resources/read request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ReadResourceRequest(
|
||||
method="resources/read",
|
||||
params=types.ReadResourceRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.ReadResourceResult,
|
||||
)
|
||||
|
||||
def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
|
||||
"""Send a resources/subscribe request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.SubscribeRequest(
|
||||
method="resources/subscribe",
|
||||
params=types.SubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
|
||||
"""Send a resources/unsubscribe request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.UnsubscribeRequest(
|
||||
method="resources/unsubscribe",
|
||||
params=types.UnsubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
def call_tool(
|
||||
self,
|
||||
name: str,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
) -> types.CallToolResult:
|
||||
"""Send a tools/call request."""
|
||||
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.CallToolRequest(
|
||||
method="tools/call",
|
||||
params=types.CallToolRequestParams(name=name, arguments=arguments),
|
||||
)
|
||||
),
|
||||
types.CallToolResult,
|
||||
request_read_timeout_seconds=read_timeout_seconds,
|
||||
)
|
||||
|
||||
def list_prompts(self) -> types.ListPromptsResult:
|
||||
"""Send a prompts/list request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListPromptsRequest(
|
||||
method="prompts/list",
|
||||
)
|
||||
),
|
||||
types.ListPromptsResult,
|
||||
)
|
||||
|
||||
def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
|
||||
"""Send a prompts/get request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.GetPromptRequest(
|
||||
method="prompts/get",
|
||||
params=types.GetPromptRequestParams(name=name, arguments=arguments),
|
||||
)
|
||||
),
|
||||
types.GetPromptResult,
|
||||
)
|
||||
|
||||
def complete(
|
||||
self,
|
||||
ref: types.ResourceTemplateReference | types.PromptReference,
|
||||
argument: dict[str, str],
|
||||
) -> types.CompleteResult:
|
||||
"""Send a completion/complete request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.CompleteRequest(
|
||||
method="completion/complete",
|
||||
params=types.CompleteRequestParams(
|
||||
ref=ref,
|
||||
argument=types.CompletionArgument.model_validate(argument),
|
||||
),
|
||||
)
|
||||
),
|
||||
types.CompleteResult,
|
||||
)
|
||||
|
||||
def list_tools(self) -> types.ListToolsResult:
|
||||
"""Send a tools/list request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListToolsRequest(
|
||||
method="tools/list",
|
||||
)
|
||||
),
|
||||
types.ListToolsResult,
|
||||
)
|
||||
|
||||
def send_roots_list_changed(self):
|
||||
"""Send a roots/list_changed notification."""
|
||||
self.send_notification(
|
||||
types.ClientNotification(
|
||||
types.RootsListChangedNotification(
|
||||
method="notifications/roots/list_changed",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]):
|
||||
ctx = RequestContext[ClientSession, Any](
|
||||
request_id=responder.request_id,
|
||||
meta=responder.request_meta,
|
||||
session=self,
|
||||
lifespan_context=None,
|
||||
)
|
||||
|
||||
match responder.request.root:
|
||||
case types.CreateMessageRequest(params=params):
|
||||
with responder:
|
||||
response = self._sampling_callback(ctx, params)
|
||||
client_response = ClientResponse.validate_python(response)
|
||||
responder.respond(client_response)
|
||||
|
||||
case types.ListRootsRequest():
|
||||
with responder:
|
||||
list_roots_response = self._list_roots_callback(ctx)
|
||||
client_response = ClientResponse.validate_python(list_roots_response)
|
||||
responder.respond(client_response)
|
||||
|
||||
case types.PingRequest():
|
||||
with responder:
|
||||
return responder.respond(types.ClientResult(root=types.EmptyResult()))
|
||||
|
||||
def _handle_incoming(
|
||||
self,
|
||||
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
):
|
||||
"""Handle incoming messages by forwarding to the message handler."""
|
||||
self._message_handler(req)
|
||||
|
||||
def _received_notification(self, notification: types.ServerNotification):
|
||||
"""Handle notifications from the server."""
|
||||
# Process specific notification types
|
||||
match notification.root:
|
||||
case types.LoggingMessageNotification(params=params):
|
||||
self._logging_callback(params)
|
||||
case _:
|
||||
pass
|
||||
1342
dify/api/core/mcp/types.py
Normal file
1342
dify/api/core/mcp/types.py
Normal file
File diff suppressed because it is too large
Load Diff
142
dify/api/core/mcp/utils.py
Normal file
142
dify/api/core/mcp/utils.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from contextlib import AbstractContextManager
|
||||
|
||||
import httpx
|
||||
import httpx_sse
|
||||
from httpx_sse import connect_sse
|
||||
|
||||
from configs import dify_config
|
||||
from core.mcp.types import ErrorData, JSONRPCError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
|
||||
STATUS_FORCELIST = [429, 500, 502, 503, 504]
|
||||
|
||||
|
||||
def create_ssrf_proxy_mcp_http_client(
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: httpx.Timeout | None = None,
|
||||
) -> httpx.Client:
|
||||
"""Create an HTTPX client with SSRF proxy configuration for MCP connections.
|
||||
|
||||
Args:
|
||||
headers: Optional headers to include in the client
|
||||
timeout: Optional timeout configuration
|
||||
|
||||
Returns:
|
||||
Configured httpx.Client with proxy settings
|
||||
"""
|
||||
if dify_config.SSRF_PROXY_ALL_URL:
|
||||
return httpx.Client(
|
||||
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
headers=headers or {},
|
||||
timeout=timeout,
|
||||
follow_redirects=True,
|
||||
proxy=dify_config.SSRF_PROXY_ALL_URL,
|
||||
)
|
||||
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
|
||||
proxy_mounts = {
|
||||
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY),
|
||||
"https://": httpx.HTTPTransport(
|
||||
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
),
|
||||
}
|
||||
return httpx.Client(
|
||||
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
headers=headers or {},
|
||||
timeout=timeout,
|
||||
follow_redirects=True,
|
||||
mounts=proxy_mounts,
|
||||
)
|
||||
else:
|
||||
return httpx.Client(
|
||||
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
headers=headers or {},
|
||||
timeout=timeout,
|
||||
follow_redirects=True,
|
||||
)
|
||||
|
||||
|
||||
def ssrf_proxy_sse_connect(url: str, **kwargs) -> AbstractContextManager[httpx_sse.EventSource]:
|
||||
"""Connect to SSE endpoint with SSRF proxy protection.
|
||||
|
||||
This function creates an SSE connection using the configured proxy settings
|
||||
to prevent SSRF attacks when connecting to external endpoints. It returns
|
||||
a context manager that yields an EventSource object for SSE streaming.
|
||||
|
||||
The function handles HTTP client creation and cleanup automatically, but
|
||||
also accepts a pre-configured client via kwargs.
|
||||
|
||||
Args:
|
||||
url (str): The SSE endpoint URL to connect to
|
||||
**kwargs: Additional arguments passed to the SSE connection, including:
|
||||
- client (httpx.Client, optional): Pre-configured HTTP client.
|
||||
If not provided, one will be created with SSRF protection.
|
||||
- method (str, optional): HTTP method to use, defaults to "GET"
|
||||
- headers (dict, optional): HTTP headers to include in the request
|
||||
- timeout (httpx.Timeout, optional): Timeout configuration for the connection
|
||||
|
||||
Returns:
|
||||
AbstractContextManager[httpx_sse.EventSource]: A context manager that yields an EventSource
|
||||
object for SSE streaming. The EventSource provides access to server-sent events.
|
||||
|
||||
Example:
|
||||
```python
|
||||
with ssrf_proxy_sse_connect(url, headers=headers) as event_source:
|
||||
for sse in event_source.iter_sse():
|
||||
print(sse.event, sse.data)
|
||||
```
|
||||
|
||||
Note:
|
||||
If a client is not provided in kwargs, one will be automatically created
|
||||
with SSRF protection based on the application's configuration. If an
|
||||
exception occurs during connection, any automatically created client
|
||||
will be cleaned up automatically.
|
||||
"""
|
||||
|
||||
# Extract client if provided, otherwise create one
|
||||
client = kwargs.pop("client", None)
|
||||
if client is None:
|
||||
# Create client with SSRF proxy configuration
|
||||
timeout = kwargs.pop(
|
||||
"timeout",
|
||||
httpx.Timeout(
|
||||
timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
|
||||
connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
|
||||
read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
|
||||
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
|
||||
),
|
||||
)
|
||||
headers = kwargs.pop("headers", {})
|
||||
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
|
||||
client_provided = False
|
||||
else:
|
||||
client_provided = True
|
||||
|
||||
# Extract method if provided, default to GET
|
||||
method = kwargs.pop("method", "GET")
|
||||
|
||||
try:
|
||||
return connect_sse(client, method, url, **kwargs)
|
||||
except Exception:
|
||||
# If we created the client, we need to clean it up on error
|
||||
if not client_provided:
|
||||
client.close()
|
||||
raise
|
||||
|
||||
|
||||
def create_mcp_error_response(
|
||||
request_id: int | str | None, code: int, message: str, data=None
|
||||
) -> Generator[bytes, None, None]:
|
||||
"""Create MCP error response"""
|
||||
error_data = ErrorData(code=code, message=message, data=data)
|
||||
json_response = JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=request_id or 1,
|
||||
error=error_data,
|
||||
)
|
||||
json_data = json.dumps(jsonable_encoder(json_response))
|
||||
sse_content = json_data.encode()
|
||||
yield sse_content
|
||||
Reference in New Issue
Block a user