This commit is contained in:
2025-12-01 17:21:38 +08:00
parent 32fee2b8ab
commit fab8c13cb3
7511 changed files with 996300 additions and 0 deletions

View File

View File

@@ -0,0 +1,705 @@
import base64
import hashlib
import json
import os
import secrets
import urllib.parse
from urllib.parse import urljoin, urlparse
import httpx
from httpx import RequestError
from pydantic import ValidationError
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
from core.helper import ssrf_proxy
from core.mcp.entities import AuthAction, AuthActionType, AuthResult, OAuthCallbackState
from core.mcp.error import MCPRefreshTokenError
from core.mcp.types import (
LATEST_PROTOCOL_VERSION,
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthTokens,
ProtectedResourceMetadata,
)
from extensions.ext_redis import redis_client
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
def generate_pkce_challenge() -> tuple[str, str]:
"""Generate PKCE challenge and verifier."""
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
code_verifier = code_verifier.replace("=", "").replace("+", "-").replace("/", "_")
code_challenge_hash = hashlib.sha256(code_verifier.encode("utf-8")).digest()
code_challenge = base64.urlsafe_b64encode(code_challenge_hash).decode("utf-8")
code_challenge = code_challenge.replace("=", "").replace("+", "-").replace("/", "_")
return code_verifier, code_challenge
def build_protected_resource_metadata_discovery_urls(
www_auth_resource_metadata_url: str | None, server_url: str
) -> list[str]:
"""
Build a list of URLs to try for Protected Resource Metadata discovery.
Per SEP-985, supports fallback when discovery fails at one URL.
"""
urls = []
# First priority: URL from WWW-Authenticate header
if www_auth_resource_metadata_url:
urls.append(www_auth_resource_metadata_url)
# Fallback: construct from server URL
parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
if fallback_url not in urls:
urls.append(fallback_url)
return urls
def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]:
"""
Build a list of URLs to try for OAuth Authorization Server Metadata discovery.
Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.
Per RFC 8414 section 3:
- If issuer has no path: https://example.com/.well-known/oauth-authorization-server
- If issuer has path: https://example.com/.well-known/oauth-authorization-server{path}
Example:
- issuer: https://example.com/oauth
- metadata: https://example.com/.well-known/oauth-authorization-server/oauth
"""
urls = []
base_url = auth_server_url or server_url
parsed = urlparse(base_url)
base = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path.rstrip("/") # Remove trailing slash
# Try OpenID Connect discovery first (more common)
urls.append(urljoin(base + "/", ".well-known/openid-configuration"))
# OAuth 2.0 Authorization Server Metadata (RFC 8414)
# Include the path component if present in the issuer URL
if path:
urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}"))
else:
urls.append(urljoin(base, ".well-known/oauth-authorization-server"))
return urls
def discover_protected_resource_metadata(
prm_url: str | None, server_url: str, protocol_version: str | None = None
) -> ProtectedResourceMetadata | None:
"""Discover OAuth 2.0 Protected Resource Metadata (RFC 9470)."""
urls = build_protected_resource_metadata_discovery_urls(prm_url, server_url)
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
for url in urls:
try:
response = ssrf_proxy.get(url, headers=headers)
if response.status_code == 200:
return ProtectedResourceMetadata.model_validate(response.json())
elif response.status_code == 404:
continue # Try next URL
except (RequestError, ValidationError):
continue # Try next URL
return None
def discover_oauth_authorization_server_metadata(
auth_server_url: str | None, server_url: str, protocol_version: str | None = None
) -> OAuthMetadata | None:
"""Discover OAuth 2.0 Authorization Server Metadata (RFC 8414)."""
urls = build_oauth_authorization_server_metadata_discovery_urls(auth_server_url, server_url)
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
for url in urls:
try:
response = ssrf_proxy.get(url, headers=headers)
if response.status_code == 200:
return OAuthMetadata.model_validate(response.json())
elif response.status_code == 404:
continue # Try next URL
except (RequestError, ValidationError):
continue # Try next URL
return None
def get_effective_scope(
scope_from_www_auth: str | None,
prm: ProtectedResourceMetadata | None,
asm: OAuthMetadata | None,
client_scope: str | None,
) -> str | None:
"""
Determine effective scope using priority-based selection strategy.
Priority order:
1. WWW-Authenticate header scope (server explicit requirement)
2. Protected Resource Metadata scopes
3. OAuth Authorization Server Metadata scopes
4. Client configured scope
"""
if scope_from_www_auth:
return scope_from_www_auth
if prm and prm.scopes_supported:
return " ".join(prm.scopes_supported)
if asm and asm.scopes_supported:
return " ".join(asm.scopes_supported)
return client_scope
def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
"""Create a secure state parameter by storing state data in Redis and returning a random state key."""
# Generate a secure random state key
state_key = secrets.token_urlsafe(32)
# Store the state data in Redis with expiration
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
redis_client.setex(redis_key, OAUTH_STATE_EXPIRY_SECONDS, state_data.model_dump_json())
return state_key
def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
"""Retrieve and decode OAuth state data from Redis using the state key, then delete it."""
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
# Get state data from Redis
state_data = redis_client.get(redis_key)
if not state_data:
raise ValueError("State parameter has expired or does not exist")
# Delete the state data from Redis immediately after retrieval to prevent reuse
redis_client.delete(redis_key)
try:
# Parse and validate the state data
oauth_state = OAuthCallbackState.model_validate_json(state_data)
return oauth_state
except ValidationError as e:
raise ValueError(f"Invalid state parameter: {str(e)}")
def handle_callback(state_key: str, authorization_code: str) -> tuple[OAuthCallbackState, OAuthTokens]:
"""
Handle the callback from the OAuth provider.
Returns:
A tuple of (callback_state, tokens) that can be used by the caller to save data.
"""
# Retrieve state data from Redis (state is automatically deleted after retrieval)
full_state_data = _retrieve_redis_state(state_key)
tokens = exchange_authorization(
full_state_data.server_url,
full_state_data.metadata,
full_state_data.client_information,
authorization_code,
full_state_data.code_verifier,
full_state_data.redirect_uri,
)
return full_state_data, tokens
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
"""Check if the server supports OAuth 2.0 Resource Discovery."""
b_scheme, b_netloc, _, _, b_query, b_fragment = urlparse(server_url, "", True)
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource"
if b_query:
url_for_resource_discovery += f"?{b_query}"
if b_fragment:
url_for_resource_discovery += f"#{b_fragment}"
try:
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
response = ssrf_proxy.get(url_for_resource_discovery, headers=headers)
if 200 <= response.status_code < 300:
body = response.json()
# Support both singular and plural forms
if body.get("authorization_servers"):
return True, body["authorization_servers"][0]
elif body.get("authorization_server_url"):
return True, body["authorization_server_url"][0]
else:
return False, ""
return False, ""
except RequestError:
# Not support resource discovery, fall back to well-known OAuth metadata
return False, ""
def discover_oauth_metadata(
server_url: str,
resource_metadata_url: str | None = None,
scope_hint: str | None = None,
protocol_version: str | None = None,
) -> tuple[OAuthMetadata | None, ProtectedResourceMetadata | None, str | None]:
"""
Discover OAuth metadata using RFC 8414/9470 standards.
Args:
server_url: The MCP server URL
resource_metadata_url: Protected Resource Metadata URL from WWW-Authenticate header
scope_hint: Scope hint from WWW-Authenticate header
protocol_version: MCP protocol version
Returns:
(oauth_metadata, protected_resource_metadata, scope_hint)
"""
# Discover Protected Resource Metadata
prm = discover_protected_resource_metadata(resource_metadata_url, server_url, protocol_version)
# Get authorization server URL from PRM or use server URL
auth_server_url = None
if prm and prm.authorization_servers:
auth_server_url = prm.authorization_servers[0]
# Discover OAuth Authorization Server Metadata
asm = discover_oauth_authorization_server_metadata(auth_server_url, server_url, protocol_version)
return asm, prm, scope_hint
def start_authorization(
server_url: str,
metadata: OAuthMetadata | None,
client_information: OAuthClientInformation,
redirect_url: str,
provider_id: str,
tenant_id: str,
scope: str | None = None,
) -> tuple[str, str]:
"""Begins the authorization flow with secure Redis state storage."""
response_type = "code"
code_challenge_method = "S256"
if metadata:
authorization_url = metadata.authorization_endpoint
if response_type not in metadata.response_types_supported:
raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
else:
authorization_url = urljoin(server_url, "/authorize")
code_verifier, code_challenge = generate_pkce_challenge()
# Prepare state data with all necessary information
state_data = OAuthCallbackState(
provider_id=provider_id,
tenant_id=tenant_id,
server_url=server_url,
metadata=metadata,
client_information=client_information,
code_verifier=code_verifier,
redirect_uri=redirect_url,
)
# Store state data in Redis and generate secure state key
state_key = _create_secure_redis_state(state_data)
params = {
"response_type": response_type,
"client_id": client_information.client_id,
"code_challenge": code_challenge,
"code_challenge_method": code_challenge_method,
"redirect_uri": redirect_url,
"state": state_key,
}
# Add scope if provided
if scope:
params["scope"] = scope
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
return authorization_url, code_verifier
def _parse_token_response(response: httpx.Response) -> OAuthTokens:
"""
Parse OAuth token response supporting both JSON and form-urlencoded formats.
Per RFC 6749 Section 5.1, the standard format is JSON.
However, some legacy OAuth providers (e.g., early GitHub OAuth Apps) return
application/x-www-form-urlencoded format for backwards compatibility.
Args:
response: The HTTP response from token endpoint
Returns:
Parsed OAuth tokens
Raises:
ValueError: If response cannot be parsed
"""
content_type = response.headers.get("content-type", "").lower()
if "application/json" in content_type:
# Standard OAuth 2.0 JSON response (RFC 6749)
return OAuthTokens.model_validate(response.json())
elif "application/x-www-form-urlencoded" in content_type:
# Legacy form-urlencoded response (non-standard but used by some providers)
token_data = dict(urllib.parse.parse_qsl(response.text))
return OAuthTokens.model_validate(token_data)
else:
# No content-type or unknown - try JSON first, fallback to form-urlencoded
try:
return OAuthTokens.model_validate(response.json())
except (ValidationError, json.JSONDecodeError):
token_data = dict(urllib.parse.parse_qsl(response.text))
return OAuthTokens.model_validate(token_data)
def exchange_authorization(
server_url: str,
metadata: OAuthMetadata | None,
client_information: OAuthClientInformation,
authorization_code: str,
code_verifier: str,
redirect_uri: str,
) -> OAuthTokens:
"""Exchanges an authorization code for an access token."""
grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
if metadata:
token_url = metadata.token_endpoint
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
else:
token_url = urljoin(server_url, "/token")
params = {
"grant_type": grant_type,
"client_id": client_information.client_id,
"code": authorization_code,
"code_verifier": code_verifier,
"redirect_uri": redirect_uri,
}
if client_information.client_secret:
params["client_secret"] = client_information.client_secret
response = ssrf_proxy.post(token_url, data=params)
if not response.is_success:
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
return _parse_token_response(response)
def refresh_authorization(
server_url: str,
metadata: OAuthMetadata | None,
client_information: OAuthClientInformation,
refresh_token: str,
) -> OAuthTokens:
"""Exchange a refresh token for an updated access token."""
grant_type = MCPSupportGrantType.REFRESH_TOKEN.value
if metadata:
token_url = metadata.token_endpoint
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
else:
token_url = urljoin(server_url, "/token")
params = {
"grant_type": grant_type,
"client_id": client_information.client_id,
"refresh_token": refresh_token,
}
if client_information.client_secret:
params["client_secret"] = client_information.client_secret
try:
response = ssrf_proxy.post(token_url, data=params)
except ssrf_proxy.MaxRetriesExceededError as e:
raise MCPRefreshTokenError(e) from e
if not response.is_success:
raise MCPRefreshTokenError(response.text)
return _parse_token_response(response)
def client_credentials_flow(
server_url: str,
metadata: OAuthMetadata | None,
client_information: OAuthClientInformation,
scope: str | None = None,
) -> OAuthTokens:
"""Execute Client Credentials Flow to get access token."""
grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
if metadata:
token_url = metadata.token_endpoint
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
else:
token_url = urljoin(server_url, "/token")
# Support both Basic Auth and body parameters for client authentication
headers = {"Content-Type": "application/x-www-form-urlencoded"}
data = {"grant_type": grant_type}
if scope:
data["scope"] = scope
# If client_secret is provided, use Basic Auth (preferred method)
if client_information.client_secret:
credentials = f"{client_information.client_id}:{client_information.client_secret}"
encoded_credentials = base64.b64encode(credentials.encode()).decode()
headers["Authorization"] = f"Basic {encoded_credentials}"
else:
# Fall back to including credentials in the body
data["client_id"] = client_information.client_id
if client_information.client_secret:
data["client_secret"] = client_information.client_secret
response = ssrf_proxy.post(token_url, headers=headers, data=data)
if not response.is_success:
raise ValueError(
f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}"
)
return _parse_token_response(response)
def register_client(
server_url: str,
metadata: OAuthMetadata | None,
client_metadata: OAuthClientMetadata,
) -> OAuthClientInformationFull:
"""Performs OAuth 2.0 Dynamic Client Registration."""
if metadata:
if not metadata.registration_endpoint:
raise ValueError("Incompatible auth server: does not support dynamic client registration")
registration_url = metadata.registration_endpoint
else:
registration_url = urljoin(server_url, "/register")
response = ssrf_proxy.post(
registration_url,
json=client_metadata.model_dump(),
headers={"Content-Type": "application/json"},
)
if not response.is_success:
response.raise_for_status()
return OAuthClientInformationFull.model_validate(response.json())
def auth(
provider: MCPProviderEntity,
authorization_code: str | None = None,
state_param: str | None = None,
resource_metadata_url: str | None = None,
scope_hint: str | None = None,
) -> AuthResult:
"""
Orchestrates the full auth flow with a server using secure Redis state storage.
This function performs only network operations and returns actions that need
to be performed by the caller (such as saving data to database).
Args:
provider: The MCP provider entity
authorization_code: Optional authorization code from OAuth callback
state_param: Optional state parameter from OAuth callback
resource_metadata_url: Optional Protected Resource Metadata URL from WWW-Authenticate
scope_hint: Optional scope hint from WWW-Authenticate header
Returns:
AuthResult containing actions to be performed and response data
"""
actions: list[AuthAction] = []
server_url = provider.decrypt_server_url()
# Discover OAuth metadata using RFC 8414/9470 standards
server_metadata, prm, scope_from_www_auth = discover_oauth_metadata(
server_url, resource_metadata_url, scope_hint, LATEST_PROTOCOL_VERSION
)
client_metadata = provider.client_metadata
provider_id = provider.id
tenant_id = provider.tenant_id
client_information = provider.retrieve_client_information()
redirect_url = provider.redirect_url
credentials = provider.decrypt_credentials()
# Determine grant type based on server metadata
if not server_metadata:
raise ValueError("Failed to discover OAuth metadata from server")
supported_grant_types = server_metadata.grant_types_supported or []
# Convert to lowercase for comparison
supported_grant_types_lower = [gt.lower() for gt in supported_grant_types]
# Determine which grant type to use
effective_grant_type = None
if MCPSupportGrantType.AUTHORIZATION_CODE.value in supported_grant_types_lower:
effective_grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
else:
effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
# Determine effective scope using priority-based strategy
effective_scope = get_effective_scope(scope_from_www_auth, prm, server_metadata, credentials.get("scope"))
if not client_information:
if authorization_code is not None:
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
# For client credentials flow, we don't need to register client dynamically
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
# Client should provide client_id and client_secret directly
raise ValueError("Client credentials flow requires client_id and client_secret to be provided")
try:
full_information = register_client(server_url, server_metadata, client_metadata)
except RequestError as e:
raise ValueError(f"Could not register OAuth client: {e}")
# Return action to save client information
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_CLIENT_INFO,
data={"client_information": full_information.model_dump()},
provider_id=provider_id,
tenant_id=tenant_id,
)
)
client_information = full_information
# Handle client credentials flow
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
# Direct token request without user interaction
try:
tokens = client_credentials_flow(
server_url,
server_metadata,
client_information,
effective_scope,
)
# Return action to save tokens and grant type
token_data = tokens.model_dump()
token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_TOKENS,
data=token_data,
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return AuthResult(actions=actions, response={"result": "success"})
except (RequestError, ValueError, KeyError) as e:
# RequestError: HTTP request failed
# ValueError: Invalid response data
# KeyError: Missing required fields in response
raise ValueError(f"Client credentials flow failed: {e}")
# Exchange authorization code for tokens (Authorization Code flow)
if authorization_code is not None:
if not state_param:
raise ValueError("State parameter is required when exchanging authorization code")
try:
# Retrieve state data from Redis using state key
full_state_data = _retrieve_redis_state(state_param)
code_verifier = full_state_data.code_verifier
redirect_uri = full_state_data.redirect_uri
if not code_verifier or not redirect_uri:
raise ValueError("Missing code_verifier or redirect_uri in state data")
except (json.JSONDecodeError, ValueError) as e:
raise ValueError(f"Invalid state parameter: {e}")
tokens = exchange_authorization(
server_url,
server_metadata,
client_information,
authorization_code,
code_verifier,
redirect_uri,
)
# Return action to save tokens
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_TOKENS,
data=tokens.model_dump(),
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return AuthResult(actions=actions, response={"result": "success"})
provider_tokens = provider.retrieve_tokens()
# Handle token refresh or new authorization
if provider_tokens and provider_tokens.refresh_token:
try:
new_tokens = refresh_authorization(
server_url, server_metadata, client_information, provider_tokens.refresh_token
)
# Return action to save new tokens
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_TOKENS,
data=new_tokens.model_dump(),
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return AuthResult(actions=actions, response={"result": "success"})
except (RequestError, ValueError, KeyError) as e:
# RequestError: HTTP request failed
# ValueError: Invalid response data
# KeyError: Missing required fields in response
raise ValueError(f"Could not refresh OAuth tokens: {e}")
# Start new authorization flow (only for authorization code flow)
authorization_url, code_verifier = start_authorization(
server_url,
server_metadata,
client_information,
redirect_url,
provider_id,
tenant_id,
effective_scope,
)
# Return action to save code verifier
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_CODE_VERIFIER,
data={"code_verifier": code_verifier},
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return AuthResult(actions=actions, response={"authorization_url": authorization_url})

View File

@@ -0,0 +1,197 @@
"""
MCP Client with Authentication Retry Support
This module provides an enhanced MCPClient that automatically handles
authentication failures and retries operations after refreshing tokens.
"""
import logging
from collections.abc import Callable
from typing import Any
from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.error import MCPAuthError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import CallToolResult, Tool
from extensions.ext_database import db
logger = logging.getLogger(__name__)
class MCPClientWithAuthRetry(MCPClient):
"""
An enhanced MCPClient that provides automatic authentication retry.
This class extends MCPClient and intercepts MCPAuthError exceptions
to refresh authentication before retrying failed operations.
Note: This class uses lazy session creation - database sessions are only
created when authentication retry is actually needed, not on every request.
"""
def __init__(
self,
server_url: str,
headers: dict[str, str] | None = None,
timeout: float | None = None,
sse_read_timeout: float | None = None,
provider_entity: MCPProviderEntity | None = None,
authorization_code: str | None = None,
by_server_id: bool = False,
):
"""
Initialize the MCP client with auth retry capability.
Args:
server_url: The MCP server URL
headers: Optional headers for requests
timeout: Request timeout
sse_read_timeout: SSE read timeout
provider_entity: Provider entity for authentication
authorization_code: Optional authorization code for initial auth
by_server_id: Whether to look up provider by server ID
"""
super().__init__(server_url, headers, timeout, sse_read_timeout)
self.provider_entity = provider_entity
self.authorization_code = authorization_code
self.by_server_id = by_server_id
self._has_retried = False
def _handle_auth_error(self, error: MCPAuthError) -> None:
"""
Handle authentication error by refreshing tokens.
This method creates a short-lived database session only when authentication
retry is needed, minimizing database connection hold time.
Args:
error: The authentication error
Raises:
MCPAuthError: If authentication fails or max retries reached
"""
if not self.provider_entity:
raise error
if self._has_retried:
raise error
self._has_retried = True
try:
# Create a temporary session only for auth retry
# This session is short-lived and only exists during the auth operation
from services.tools.mcp_tools_manage_service import MCPToolManageService
with Session(db.engine) as session, session.begin():
mcp_service = MCPToolManageService(session=session)
# Perform authentication using the service's auth method
# Extract OAuth metadata hints from the error
mcp_service.auth_with_actions(
self.provider_entity,
self.authorization_code,
resource_metadata_url=error.resource_metadata_url,
scope_hint=error.scope_hint,
)
# Retrieve new tokens
self.provider_entity = mcp_service.get_provider_entity(
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
)
# Session is closed here, before we update headers
token = self.provider_entity.retrieve_tokens()
if not token:
raise MCPAuthError("Authentication failed - no token received")
# Update headers with new token
self.headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
# Clear authorization code after first use
self.authorization_code = None
except MCPAuthError:
# Re-raise MCPAuthError as is
raise
except Exception as e:
# Catch all exceptions during auth retry
logger.exception("Authentication retry failed")
raise MCPAuthError(f"Authentication retry failed: {e}") from e
def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any:
"""
Execute a function with authentication retry logic.
Args:
func: The function to execute
*args: Positional arguments for the function
**kwargs: Keyword arguments for the function
Returns:
The result of the function call
Raises:
MCPAuthError: If authentication fails after retries
Any other exceptions from the function
"""
try:
return func(*args, **kwargs)
except MCPAuthError as e:
self._handle_auth_error(e)
# Re-initialize the connection with new headers
if self._initialized:
# Clean up existing connection
self._exit_stack.close()
self._session = None
self._initialized = False
# Re-initialize with new headers
self._initialize()
self._initialized = True
return func(*args, **kwargs)
finally:
# Reset retry flag after operation completes
self._has_retried = False
def __enter__(self):
"""Enter the context manager with retry support."""
def initialize_with_retry():
super(MCPClientWithAuthRetry, self).__enter__()
return self
return self._execute_with_retry(initialize_with_retry)
def list_tools(self) -> list[Tool]:
"""
List available tools from the MCP server with auth retry.
Returns:
List of available tools
Raises:
MCPAuthError: If authentication fails after retries
"""
return self._execute_with_retry(super().list_tools)
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
"""
Invoke a tool on the MCP server with auth retry.
Args:
tool_name: Name of the tool to invoke
tool_args: Arguments for the tool
Returns:
Result of the tool invocation
Raises:
MCPAuthError: If authentication fails after retries
"""
return self._execute_with_retry(super().invoke_tool, tool_name, tool_args)

View File

@@ -0,0 +1,360 @@
import logging
import queue
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, TypeAlias, final
from urllib.parse import urljoin, urlparse
import httpx
from httpx_sse import EventSource, ServerSentEvent
from sseclient import SSEClient
from core.mcp import types
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.types import SessionMessage
from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
logger = logging.getLogger(__name__)
DEFAULT_QUEUE_READ_TIMEOUT = 3
@final
class _StatusReady:
def __init__(self, endpoint_url: str):
self.endpoint_url = endpoint_url
@final
class _StatusError:
def __init__(self, exc: Exception):
self.exc = exc
# Type aliases for better readability
ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
class SSETransport:
"""SSE client transport implementation."""
def __init__(
self,
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 1 * 60,
):
"""Initialize the SSE transport.
Args:
url: The SSE endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
"""
self.url = url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.endpoint_url: str | None = None
def _validate_endpoint_url(self, endpoint_url: str) -> bool:
"""Validate that the endpoint URL matches the connection origin.
Args:
endpoint_url: The endpoint URL to validate.
Returns:
True if valid, False otherwise.
"""
url_parsed = urlparse(self.url)
endpoint_parsed = urlparse(endpoint_url)
return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme
def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue):
"""Handle an 'endpoint' SSE event.
Args:
sse_data: The SSE event data.
status_queue: Queue to put status updates.
"""
endpoint_url = urljoin(self.url, sse_data)
logger.info("Received endpoint URL: %s", endpoint_url)
if not self._validate_endpoint_url(endpoint_url):
error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
logger.error(error_msg)
status_queue.put(_StatusError(ValueError(error_msg)))
return
status_queue.put(_StatusReady(endpoint_url))
def _handle_message_event(self, sse_data: str, read_queue: ReadQueue):
"""Handle a 'message' SSE event.
Args:
sse_data: The SSE event data.
read_queue: Queue to put parsed messages.
"""
try:
message = types.JSONRPCMessage.model_validate_json(sse_data)
logger.debug("Received server message: %s", message)
session_message = SessionMessage(message)
read_queue.put(session_message)
except Exception as exc:
logger.exception("Error parsing server message")
read_queue.put(exc)
def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue):
"""Handle a single SSE event.
Args:
sse: The SSE event object.
read_queue: Queue for message events.
status_queue: Queue for status events.
"""
match sse.event:
case "endpoint":
self._handle_endpoint_event(sse.data, status_queue)
case "message":
self._handle_message_event(sse.data, read_queue)
case _:
logger.warning("Unknown SSE event: %s", sse.event)
def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue):
"""Read and process SSE events.
Args:
event_source: The SSE event source.
read_queue: Queue to put received messages.
status_queue: Queue to put status updates.
"""
try:
for sse in event_source.iter_sse():
self._handle_sse_event(sse, read_queue, status_queue)
except httpx.ReadError as exc:
logger.debug("SSE reader shutting down normally: %s", exc)
except Exception as exc:
read_queue.put(exc)
finally:
read_queue.put(None)
def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage):
"""Send a single message to the server.
Args:
client: HTTP client to use.
endpoint_url: The endpoint URL to send to.
message: The message to send.
"""
response = client.post(
endpoint_url,
json=message.message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug("Client message sent successfully: %s", response.status_code)
def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue):
"""Handle writing messages to the server.
Args:
client: HTTP client to use.
endpoint_url: The endpoint URL to send messages to.
write_queue: Queue to read messages from.
"""
try:
while True:
try:
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if message is None:
break
if isinstance(message, Exception):
write_queue.put(message)
continue
self._send_message(client, endpoint_url, message)
except queue.Empty:
continue
except httpx.ReadError as exc:
logger.debug("Post writer shutting down normally: %s", exc)
except Exception as exc:
logger.exception("Error writing messages")
write_queue.put(exc)
finally:
write_queue.put(None)
def _wait_for_endpoint(self, status_queue: StatusQueue) -> str:
"""Wait for the endpoint URL from the status queue.
Args:
status_queue: Queue to read status from.
Returns:
The endpoint URL.
Raises:
ValueError: If endpoint URL is not received or there's an error.
"""
try:
status = status_queue.get(timeout=1)
except queue.Empty:
raise ValueError("failed to get endpoint URL")
if isinstance(status, _StatusReady):
return status.endpoint_url
elif isinstance(status, _StatusError):
raise status.exc
else:
raise ValueError("failed to get endpoint URL")
def connect(
self,
executor: ThreadPoolExecutor,
client: httpx.Client,
event_source: EventSource,
) -> tuple[ReadQueue, WriteQueue]:
"""Establish connection and start worker threads.
Args:
executor: Thread pool executor.
client: HTTP client.
event_source: SSE event source.
Returns:
Tuple of (read_queue, write_queue).
"""
read_queue: ReadQueue = queue.Queue()
write_queue: WriteQueue = queue.Queue()
status_queue: StatusQueue = queue.Queue()
# Start SSE reader thread
executor.submit(self.sse_reader, event_source, read_queue, status_queue)
# Wait for endpoint URL
endpoint_url = self._wait_for_endpoint(status_queue)
self.endpoint_url = endpoint_url
# Start post writer thread
executor.submit(self.post_writer, client, endpoint_url, write_queue)
return read_queue, write_queue
@contextmanager
def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 1 * 60,
) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
"""
Client transport for SSE.
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
Args:
url: The SSE endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
Yields:
Tuple of (read_queue, write_queue) for message communication.
"""
transport = SSETransport(url, headers, timeout, sse_read_timeout)
read_queue: ReadQueue | None = None
write_queue: WriteQueue | None = None
executor = ThreadPoolExecutor()
try:
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
with ssrf_proxy_sse_connect(
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
) as event_source:
event_source.response.raise_for_status()
read_queue, write_queue = transport.connect(executor, client, event_source)
yield read_queue, write_queue
except httpx.HTTPStatusError as exc:
if exc.response.status_code == 401:
raise MCPAuthError(response=exc.response)
raise MCPConnectionError()
except Exception:
logger.exception("Error connecting to SSE endpoint")
raise
finally:
# Clean up queues
if read_queue:
read_queue.put(None)
if write_queue:
write_queue.put(None)
# Shutdown executor without waiting to prevent hanging
executor.shutdown(wait=False)
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage):
"""
Send a message to the server using the provided HTTP client.
Args:
http_client: The HTTP client to use for sending
endpoint_url: The endpoint URL to send the message to
session_message: The message to send
"""
try:
response = http_client.post(
endpoint_url,
json=session_message.message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug("Client message sent successfully: %s", response.status_code)
except Exception:
logger.exception("Error sending message")
raise
def read_messages(
sse_client: SSEClient,
) -> Generator[SessionMessage | Exception, None, None]:
"""
Read messages from the SSE client.
Args:
sse_client: The SSE client to read from
Yields:
SessionMessage or Exception for each event received
"""
try:
for sse in sse_client.events():
if sse.event == "message":
try:
message = types.JSONRPCMessage.model_validate_json(sse.data)
logger.debug("Received server message: %s", message)
yield SessionMessage(message)
except Exception as exc:
logger.exception("Error parsing server message")
yield exc
else:
logger.warning("Unknown SSE event: %s", sse.event)
except Exception as exc:
logger.exception("Error reading SSE messages")
yield exc

View File

@@ -0,0 +1,485 @@
"""
StreamableHTTP Client Transport Module
This module implements the StreamableHTTP transport for MCP clients,
providing support for HTTP POST requests with optional SSE streaming responses
and session management.
"""
import logging
import queue
from collections.abc import Callable, Generator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, cast
import httpx
from httpx_sse import EventSource, ServerSentEvent
from core.mcp.types import (
ClientMessageMetadata,
ErrorData,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
RequestId,
SessionMessage,
)
from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
logger = logging.getLogger(__name__)
SessionMessageOrError = SessionMessage | Exception | None
# Queue types with clearer names for their roles
ServerToClientQueue = queue.Queue[SessionMessageOrError] # Server to client messages
ClientToServerQueue = queue.Queue[SessionMessage | None] # Client to server messages
GetSessionIdCallback = Callable[[], str | None]
MCP_SESSION_ID = "mcp-session-id"
LAST_EVENT_ID = "last-event-id"
CONTENT_TYPE = "content-type"
ACCEPT = "Accept"
JSON = "application/json"
SSE = "text/event-stream"
DEFAULT_QUEUE_READ_TIMEOUT = 3
class StreamableHTTPError(Exception):
"""Base exception for StreamableHTTP transport errors."""
class ResumptionError(StreamableHTTPError):
"""Raised when resumption request is invalid."""
@dataclass
class RequestContext:
"""Context for a request operation."""
client: httpx.Client
headers: dict[str, str]
session_id: str | None
session_message: SessionMessage
metadata: ClientMessageMetadata | None
server_to_client_queue: ServerToClientQueue # Renamed for clarity
sse_read_timeout: float
class StreamableHTTPTransport:
"""StreamableHTTP client transport implementation."""
def __init__(
self,
url: str,
headers: dict[str, Any] | None = None,
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
):
"""Initialize the StreamableHTTP transport.
Args:
url: The endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
"""
self.url = url
self.headers = headers or {}
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
self.sse_read_timeout = (
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
)
self.session_id: str | None = None
self.request_headers = {
ACCEPT: f"{JSON}, {SSE}",
CONTENT_TYPE: JSON,
**self.headers,
}
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
"""Update headers with session ID if available."""
headers = base_headers.copy()
if self.session_id:
headers[MCP_SESSION_ID] = self.session_id
return headers
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialization request."""
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialized notification."""
return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized"
def _maybe_extract_session_id_from_response(
self,
response: httpx.Response,
):
"""Extract and store session ID from response headers."""
new_session_id = response.headers.get(MCP_SESSION_ID)
if new_session_id:
self.session_id = new_session_id
logger.info("Received session ID: %s", self.session_id)
def _handle_sse_event(
self,
sse: ServerSentEvent,
server_to_client_queue: ServerToClientQueue,
original_request_id: RequestId | None = None,
resumption_callback: Callable[[str], None] | None = None,
) -> bool:
"""Handle an SSE event, returning True if the response is complete."""
if sse.event == "message":
# ping event send by server will be recognized as a message event with empty data by httpx-sse's SSEDecoder
if not sse.data.strip():
return False
try:
message = JSONRPCMessage.model_validate_json(sse.data)
logger.debug("SSE message: %s", message)
# If this is a response and we have original_request_id, replace it
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
message.root.id = original_request_id
session_message = SessionMessage(message)
# Put message in queue that goes to client
server_to_client_queue.put(session_message)
# Call resumption token callback if we have an ID
if sse.id and resumption_callback:
resumption_callback(sse.id)
# If this is a response or error return True indicating completion
# Otherwise, return False to continue listening
return isinstance(message.root, JSONRPCResponse | JSONRPCError)
except Exception as exc:
# Put exception in queue that goes to client
server_to_client_queue.put(exc)
return False
elif sse.event == "ping":
logger.debug("Received ping event")
return False
else:
logger.warning("Unknown SSE event: %s", sse.event)
return False
def handle_get_stream(
self,
client: httpx.Client,
server_to_client_queue: ServerToClientQueue,
):
"""Handle GET stream for server-initiated messages."""
try:
if not self.session_id:
return
headers = self._update_headers_with_session(self.request_headers)
with ssrf_proxy_sse_connect(
self.url,
headers=headers,
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
client=client,
method="GET",
) as event_source:
event_source.response.raise_for_status()
logger.debug("GET SSE connection established")
for sse in event_source.iter_sse():
self._handle_sse_event(sse, server_to_client_queue)
except Exception as exc:
logger.debug("GET stream error (non-fatal): %s", exc)
def _handle_resumption_request(self, ctx: RequestContext):
"""Handle a resumption request using GET with SSE."""
headers = self._update_headers_with_session(ctx.headers)
if ctx.metadata and ctx.metadata.resumption_token:
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
else:
raise ResumptionError("Resumption request requires a resumption token")
# Extract original request ID to map responses
original_request_id = None
if isinstance(ctx.session_message.message.root, JSONRPCRequest):
original_request_id = ctx.session_message.message.root.id
with ssrf_proxy_sse_connect(
self.url,
headers=headers,
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
client=ctx.client,
method="GET",
) as event_source:
event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established")
for sse in event_source.iter_sse():
is_complete = self._handle_sse_event(
sse,
ctx.server_to_client_queue,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
break
def _handle_post_request(self, ctx: RequestContext):
"""Handle a POST request with response processing."""
headers = self._update_headers_with_session(ctx.headers)
message = ctx.session_message.message
is_initialization = self._is_initialization_request(message)
with ctx.client.stream(
"POST",
self.url,
json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
headers=headers,
) as response:
if response.status_code == 202:
logger.debug("Received 202 Accepted")
return
if response.status_code == 204:
logger.debug("Received 204 No Content")
return
if response.status_code == 404:
if isinstance(message.root, JSONRPCRequest):
self._send_session_terminated_error(
ctx.server_to_client_queue,
message.root.id,
)
return
response.raise_for_status()
if is_initialization:
self._maybe_extract_session_id_from_response(response)
content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
if content_type.startswith(JSON):
self._handle_json_response(response, ctx.server_to_client_queue)
elif content_type.startswith(SSE):
self._handle_sse_response(response, ctx)
else:
self._handle_unexpected_content_type(
content_type,
ctx.server_to_client_queue,
)
def _handle_json_response(
self,
response: httpx.Response,
server_to_client_queue: ServerToClientQueue,
):
"""Handle JSON response from the server."""
try:
content = response.read()
message = JSONRPCMessage.model_validate_json(content)
session_message = SessionMessage(message)
server_to_client_queue.put(session_message)
except Exception as exc:
server_to_client_queue.put(exc)
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext):
"""Handle SSE response from the server."""
try:
event_source = EventSource(response)
for sse in event_source.iter_sse():
is_complete = self._handle_sse_event(
sse,
ctx.server_to_client_queue,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
)
if is_complete:
break
except Exception as e:
ctx.server_to_client_queue.put(e)
def _handle_unexpected_content_type(
self,
content_type: str,
server_to_client_queue: ServerToClientQueue,
):
"""Handle unexpected content type in response."""
error_msg = f"Unexpected content type: {content_type}"
logger.error(error_msg)
server_to_client_queue.put(ValueError(error_msg))
def _send_session_terminated_error(
self,
server_to_client_queue: ServerToClientQueue,
request_id: RequestId,
):
"""Send a session terminated error response."""
jsonrpc_error = JSONRPCError(
jsonrpc="2.0",
id=request_id,
error=ErrorData(code=32600, message="Session terminated by server"),
)
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
server_to_client_queue.put(session_message)
def post_writer(
self,
client: httpx.Client,
client_to_server_queue: ClientToServerQueue,
server_to_client_queue: ServerToClientQueue,
start_get_stream: Callable[[], None],
):
"""Handle writing requests to the server.
This method processes messages from the client_to_server_queue and sends them to the server.
Responses are written to the server_to_client_queue.
"""
while True:
try:
# Read message from client queue with timeout to check stop_event periodically
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if session_message is None:
break
message = session_message.message
metadata = (
session_message.metadata if isinstance(session_message.metadata, ClientMessageMetadata) else None
)
# Check if this is a resumption request
is_resumption = bool(metadata and metadata.resumption_token)
logger.debug("Sending client message: %s", message)
# Handle initialized notification
if self._is_initialized_notification(message):
start_get_stream()
ctx = RequestContext(
client=client,
headers=self.request_headers,
session_id=self.session_id,
session_message=session_message,
metadata=metadata,
server_to_client_queue=server_to_client_queue, # Queue to write responses to client
sse_read_timeout=self.sse_read_timeout,
)
if is_resumption:
self._handle_resumption_request(ctx)
else:
self._handle_post_request(ctx)
except queue.Empty:
continue
except Exception as exc:
server_to_client_queue.put(exc)
def terminate_session(self, client: httpx.Client):
"""Terminate the session by sending a DELETE request."""
if not self.session_id:
return
try:
headers = self._update_headers_with_session(self.request_headers)
response = client.delete(self.url, headers=headers)
if response.status_code == 405:
logger.debug("Server does not allow session termination")
elif response.status_code != 200:
logger.warning("Session termination failed: %s", response.status_code)
except Exception as exc:
logger.warning("Session termination failed: %s", exc)
def get_session_id(self) -> str | None:
"""Get the current session ID."""
return self.session_id
@contextmanager
def streamablehttp_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
terminate_on_close: bool = True,
) -> Generator[
tuple[
ServerToClientQueue, # Queue for receiving messages FROM server
ClientToServerQueue, # Queue for sending messages TO server
GetSessionIdCallback,
],
None,
None,
]:
"""
Client transport for StreamableHTTP.
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
Yields:
Tuple containing:
- server_to_client_queue: Queue for reading messages FROM the server
- client_to_server_queue: Queue for sending messages TO the server
- get_session_id_callback: Function to retrieve the current session ID
"""
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout)
# Create queues with clear directional meaning
server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client
client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server
executor = ThreadPoolExecutor(max_workers=2)
try:
with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
) as client:
# Define callbacks that need access to thread pool
def start_get_stream():
"""Start a worker thread to handle server-initiated messages."""
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
# Start the post_writer worker thread
executor.submit(
transport.post_writer,
client,
client_to_server_queue, # Queue for messages FROM client TO server
server_to_client_queue, # Queue for messages FROM server TO client
start_get_stream,
)
try:
yield (
server_to_client_queue, # Queue for receiving messages FROM server
client_to_server_queue, # Queue for sending messages TO server
transport.get_session_id,
)
finally:
if transport.session_id and terminate_on_close:
transport.terminate_session(client)
# Signal threads to stop
client_to_server_queue.put(None)
finally:
# Clear any remaining items and add None sentinel to unblock any waiting threads
try:
while not client_to_server_queue.empty():
client_to_server_queue.get_nowait()
except queue.Empty:
pass
client_to_server_queue.put(None)
server_to_client_queue.put(None)
# Shutdown executor without waiting to prevent hanging
executor.shutdown(wait=False)

View File

@@ -0,0 +1,60 @@
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, Generic, TypeVar
from pydantic import BaseModel
from core.mcp.session.base_session import BaseSession
from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAuthMetadata, RequestId, RequestParams
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
LifespanContextT = TypeVar("LifespanContextT")
@dataclass
class RequestContext(Generic[SessionT, LifespanContextT]):
request_id: RequestId
meta: RequestParams.Meta | None
session: SessionT
lifespan_context: LifespanContextT
class AuthActionType(StrEnum):
"""Types of actions that can be performed during auth flow."""
SAVE_CLIENT_INFO = "save_client_info"
SAVE_TOKENS = "save_tokens"
SAVE_CODE_VERIFIER = "save_code_verifier"
START_AUTHORIZATION = "start_authorization"
SUCCESS = "success"
class AuthAction(BaseModel):
"""Represents an action that needs to be performed as a result of auth flow."""
action_type: AuthActionType
data: dict[str, Any]
provider_id: str | None = None
tenant_id: str | None = None
class AuthResult(BaseModel):
"""Result of auth function containing actions to be performed and response data."""
actions: list[AuthAction]
response: dict[str, str]
class OAuthCallbackState(BaseModel):
"""State data stored in Redis during OAuth callback flow."""
provider_id: str
tenant_id: str
server_url: str
metadata: OAuthMetadata | None = None
client_information: OAuthClientInformation
code_verifier: str
redirect_uri: str

View File

@@ -0,0 +1,63 @@
import re
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import httpx
class MCPError(Exception):
pass
class MCPConnectionError(MCPError):
pass
class MCPAuthError(MCPConnectionError):
def __init__(
self,
message: str | None = None,
response: "httpx.Response | None" = None,
www_authenticate_header: str | None = None,
):
"""
MCP Authentication Error.
Args:
message: Error message
response: HTTP response object (will extract WWW-Authenticate header if provided)
www_authenticate_header: Pre-extracted WWW-Authenticate header value
"""
super().__init__(message or "Authentication failed")
# Extract OAuth metadata hints from WWW-Authenticate header
if response is not None:
www_authenticate_header = response.headers.get("WWW-Authenticate")
self.resource_metadata_url: str | None = None
self.scope_hint: str | None = None
if www_authenticate_header:
self.resource_metadata_url = self._extract_field(www_authenticate_header, "resource_metadata")
self.scope_hint = self._extract_field(www_authenticate_header, "scope")
@staticmethod
def _extract_field(www_auth: str, field_name: str) -> str | None:
"""Extract a specific field from the WWW-Authenticate header."""
# Pattern to match field="value" or field=value
pattern = rf'{field_name}="([^"]*)"'
match = re.search(pattern, www_auth)
if match:
return match.group(1)
# Try without quotes
pattern = rf"{field_name}=([^\s,]+)"
match = re.search(pattern, www_auth)
if match:
return match.group(1)
return None
class MCPRefreshTokenError(MCPError):
pass

View File

@@ -0,0 +1,115 @@
import logging
from collections.abc import Callable
from contextlib import AbstractContextManager, ExitStack
from types import TracebackType
from typing import Any
from urllib.parse import urlparse
from core.mcp.client.sse_client import sse_client
from core.mcp.client.streamable_client import streamablehttp_client
from core.mcp.error import MCPConnectionError
from core.mcp.session.client_session import ClientSession
from core.mcp.types import CallToolResult, Tool
logger = logging.getLogger(__name__)
class MCPClient:
def __init__(
self,
server_url: str,
headers: dict[str, str] | None = None,
timeout: float | None = None,
sse_read_timeout: float | None = None,
):
self.server_url = server_url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
# Initialize session and client objects
self._session: ClientSession | None = None
self._exit_stack = ExitStack()
self._initialized = False
def __enter__(self):
self._initialize()
self._initialized = True
return self
def __exit__(self, exc_type: type | None, exc_value: BaseException | None, traceback: TracebackType | None):
self.cleanup()
def _initialize(
self,
):
"""Initialize the client with fallback to SSE if streamable connection fails"""
connection_methods: dict[str, Callable[..., AbstractContextManager[Any]]] = {
"mcp": streamablehttp_client,
"sse": sse_client,
}
parsed_url = urlparse(self.server_url)
path = parsed_url.path or ""
method_name = path.rstrip("/").split("/")[-1] if path else ""
if method_name in connection_methods:
client_factory = connection_methods[method_name]
self.connect_server(client_factory, method_name)
else:
try:
logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name)
self.connect_server(sse_client, "sse")
except MCPConnectionError:
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
self.connect_server(streamablehttp_client, "mcp")
def connect_server(self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str) -> None:
"""
Connect to the MCP server using streamable http or sse.
Default to streamable http.
Args:
client_factory: The client factory to use(streamablehttp_client or sse_client).
method_name: The method name to use(mcp or sse).
"""
streams_context = client_factory(
url=self.server_url,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
# Use exit_stack to manage context managers properly
if method_name == "mcp":
read_stream, write_stream, _ = self._exit_stack.enter_context(streams_context)
streams = (read_stream, write_stream)
else: # sse_client
streams = self._exit_stack.enter_context(streams_context)
session_context = ClientSession(*streams)
self._session = self._exit_stack.enter_context(session_context)
self._session.initialize()
def list_tools(self) -> list[Tool]:
"""List available tools from the MCP server"""
if not self._session:
raise ValueError("Session not initialized.")
response = self._session.list_tools()
return response.tools
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
"""Call a tool"""
if not self._session:
raise ValueError("Session not initialized.")
return self._session.call_tool(tool_name, tool_args)
def cleanup(self):
"""Clean up resources"""
try:
# ExitStack will handle proper cleanup of all managed context managers
self._exit_stack.close()
except Exception as e:
logger.exception("Error during cleanup")
raise ValueError(f"Error during cleanup: {e}")
finally:
self._session = None
self._initialized = False

View File

@@ -0,0 +1,262 @@
import json
import logging
from collections.abc import Mapping
from typing import Any, cast
from configs import dify_config
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.mcp import types as mcp_types
from models.model import App, AppMCPServer, AppMode, EndUser
from services.app_generate_service import AppGenerateService
logger = logging.getLogger(__name__)
def handle_mcp_request(
app: App,
request: mcp_types.ClientRequest,
user_input_form: list[VariableEntity],
mcp_server: AppMCPServer,
end_user: EndUser | None = None,
request_id: int | str = 1,
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError:
"""
Handle MCP request and return JSON-RPC response
Args:
app: The Dify app instance
request: The JSON-RPC request message
user_input_form: List of variable entities for the app
mcp_server: The MCP server configuration
end_user: Optional end user
request_id: The request ID
Returns:
JSON-RPC response or error
"""
request_type = type(request.root)
request_root = request.root
def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse:
"""Create success response with business result data"""
return mcp_types.JSONRPCResponse(
jsonrpc="2.0",
id=request_id,
result=result_data.model_dump(by_alias=True, mode="json", exclude_none=True),
)
def create_error_response(code: int, message: str) -> mcp_types.JSONRPCError:
"""Create error response with error code and message"""
from core.mcp.types import ErrorData
error_data = ErrorData(code=code, message=message)
return mcp_types.JSONRPCError(
jsonrpc="2.0",
id=request_id,
error=error_data,
)
try:
# Dispatch request to appropriate handler based on instance type
if isinstance(request_root, mcp_types.InitializeRequest):
return create_success_response(handle_initialize(mcp_server.description))
elif isinstance(request_root, mcp_types.ListToolsRequest):
return create_success_response(
handle_list_tools(
app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict
)
)
elif isinstance(request_root, mcp_types.CallToolRequest):
return create_success_response(handle_call_tool(app, request, user_input_form, end_user))
elif isinstance(request_root, mcp_types.PingRequest):
return create_success_response(handle_ping())
else:
return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}")
except ValueError as e:
logger.exception("Invalid params")
return create_error_response(mcp_types.INVALID_PARAMS, str(e))
except Exception as e:
logger.exception("Internal server error")
return create_error_response(mcp_types.INTERNAL_ERROR, "Internal server error: " + str(e))
def handle_ping() -> mcp_types.EmptyResult:
"""Handle ping request"""
return mcp_types.EmptyResult()
def handle_initialize(description: str) -> mcp_types.InitializeResult:
"""Handle initialize request"""
capabilities = mcp_types.ServerCapabilities(
tools=mcp_types.ToolsCapability(listChanged=False),
)
return mcp_types.InitializeResult(
protocolVersion=mcp_types.SERVER_LATEST_PROTOCOL_VERSION,
capabilities=capabilities,
serverInfo=mcp_types.Implementation(name="Dify", version=dify_config.project.version),
instructions=description,
)
def handle_list_tools(
app_name: str,
app_mode: str,
user_input_form: list[VariableEntity],
description: str,
parameters_dict: dict[str, str],
) -> mcp_types.ListToolsResult:
"""Handle list tools request"""
parameter_schema = build_parameter_schema(app_mode, user_input_form, parameters_dict)
return mcp_types.ListToolsResult(
tools=[
mcp_types.Tool(
name=app_name,
description=description,
inputSchema=parameter_schema,
)
],
)
def handle_call_tool(
app: App,
request: mcp_types.ClientRequest,
user_input_form: list[VariableEntity],
end_user: EndUser | None,
) -> mcp_types.CallToolResult:
"""Handle call tool request"""
request_obj = cast(mcp_types.CallToolRequest, request.root)
args = prepare_tool_arguments(app, request_obj.params.arguments or {})
if not end_user:
raise ValueError("End user not found")
response = AppGenerateService.generate(
app,
end_user,
args,
InvokeFrom.SERVICE_API,
streaming=app.mode == AppMode.AGENT_CHAT,
)
answer = extract_answer_from_response(app, response)
return mcp_types.CallToolResult(content=[mcp_types.TextContent(text=answer, type="text")])
def build_parameter_schema(
app_mode: str,
user_input_form: list[VariableEntity],
parameters_dict: dict[str, str],
) -> dict[str, Any]:
"""Build parameter schema for the tool"""
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
if app_mode in {AppMode.COMPLETION, AppMode.WORKFLOW}:
return {
"type": "object",
"properties": parameters,
"required": required,
}
return {
"type": "object",
"properties": {
"query": {"type": "string", "description": "User Input/Question content"},
**parameters,
},
"required": ["query", *required],
}
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]:
"""Prepare arguments based on app mode"""
if app.mode == AppMode.WORKFLOW:
return {"inputs": arguments}
elif app.mode == AppMode.COMPLETION:
return {"query": "", "inputs": arguments}
else:
# Chat modes - create a copy to avoid modifying original dict
args_copy = arguments.copy()
query = args_copy.pop("query", "")
return {"query": query, "inputs": args_copy}
def extract_answer_from_response(app: App, response: Any) -> str:
"""Extract answer from app generate response"""
answer = ""
if isinstance(response, RateLimitGenerator):
answer = process_streaming_response(response)
elif isinstance(response, Mapping):
answer = process_mapping_response(app, response)
else:
logger.warning("Unexpected response type: %s", type(response))
return answer
def process_streaming_response(response: RateLimitGenerator) -> str:
"""Process streaming response for agent chat mode"""
answer = ""
for item in response.generator:
if isinstance(item, str) and item.startswith("data: "):
try:
json_str = item[6:].strip()
parsed_data = json.loads(json_str)
if parsed_data.get("event") == "agent_thought":
answer += parsed_data.get("thought", "")
except json.JSONDecodeError:
continue
return answer
def process_mapping_response(app: App, response: Mapping) -> str:
"""Process mapping response based on app mode"""
if app.mode in {
AppMode.ADVANCED_CHAT,
AppMode.COMPLETION,
AppMode.CHAT,
AppMode.AGENT_CHAT,
}:
return response.get("answer", "")
elif app.mode == AppMode.WORKFLOW:
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
else:
raise ValueError("Invalid app mode: " + str(app.mode))
def convert_input_form_to_parameters(
user_input_form: list[VariableEntity],
parameters_dict: dict[str, str],
) -> tuple[dict[str, dict[str, Any]], list[str]]:
"""Convert user input form to parameter schema"""
parameters: dict[str, dict[str, Any]] = {}
required = []
for item in user_input_form:
if item.type in (
VariableEntityType.FILE,
VariableEntityType.FILE_LIST,
VariableEntityType.EXTERNAL_DATA_TOOL,
):
continue
parameters[item.variable] = {}
if item.required:
required.append(item.variable)
# if the workflow republished, the parameters not changed
# we should not raise error here
description = parameters_dict.get(item.variable, "")
parameters[item.variable]["description"] = description
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
parameters[item.variable]["type"] = "string"
elif item.type == VariableEntityType.SELECT:
parameters[item.variable]["type"] = "string"
parameters[item.variable]["enum"] = item.options
elif item.type == VariableEntityType.NUMBER:
parameters[item.variable]["type"] = "number"
return parameters, required

View File

@@ -0,0 +1,423 @@
import logging
import queue
from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from datetime import timedelta
from types import TracebackType
from typing import Any, Generic, Self, TypeVar
from httpx import HTTPStatusError
from pydantic import BaseModel
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.types import (
CancelledNotification,
ClientNotification,
ClientRequest,
ClientResult,
ErrorData,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
MessageMetadata,
RequestId,
RequestParams,
ServerMessageMetadata,
ServerNotification,
ServerRequest,
ServerResult,
SessionMessage,
)
logger = logging.getLogger(__name__)
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
DEFAULT_RESPONSE_READ_TIMEOUT = 1.0
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
"""Handles responding to MCP requests and manages request lifecycle.
This class MUST be used as a context manager to ensure proper cleanup and
cancellation handling:
Example:
with request_responder as resp:
resp.respond(result)
The context manager ensures:
1. Proper cancellation scope setup and cleanup
2. Request completion tracking
3. Cleanup of in-flight requests
"""
request: ReceiveRequestT
_session: Any
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
def __init__(
self,
request_id: RequestId,
request_meta: RequestParams.Meta | None,
request: ReceiveRequestT,
session: """BaseSession[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT
]""",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
):
self.request_id = request_id
self.request_meta = request_meta
self.request = request
self._session = session
self.completed = False
self._on_complete = on_complete
self._entered = False # Track if we're in a context manager
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
"""Enter the context manager, enabling request cancellation tracking."""
self._entered = True
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
):
"""Exit the context manager, performing cleanup and notifying completion."""
try:
if self.completed:
self._on_complete(self)
finally:
self._entered = False
def respond(self, response: SendResultT | ErrorData):
"""Send a response for this request.
Must be called within a context manager block.
Raises:
RuntimeError: If not used within a context manager
AssertionError: If request was already responded to
"""
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")
assert not self.completed, "Request already responded to"
self.completed = True
self._session._send_response(request_id=self.request_id, response=response)
def cancel(self):
"""Cancel this request and mark it as completed."""
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")
self.completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation
self._session._send_response(
request_id=self.request_id,
response=ErrorData(code=0, message="Request cancelled", data=None),
)
class BaseSession(
Generic[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT,
],
):
"""
Implements an MCP "session" on top of read/write streams, including features
like request/response linking, notifications, and progress.
This class is a context manager that automatically starts processing
messages when entered.
"""
_response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError | HTTPStatusError]]
_request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
_receive_request_type: type[ReceiveRequestT]
_receive_notification_type: type[ReceiveNotificationT]
def __init__(
self,
read_stream: queue.Queue,
write_stream: queue.Queue,
receive_request_type: type[ReceiveRequestT],
receive_notification_type: type[ReceiveNotificationT],
# If none, reading will never time out
read_timeout_seconds: timedelta | None = None,
):
self._read_stream = read_stream
self._write_stream = write_stream
self._response_streams = {}
self._request_id = 0
self._receive_request_type = receive_request_type
self._receive_notification_type = receive_notification_type
self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
# Initialize executor and future to None for proper cleanup checks
self._executor: ThreadPoolExecutor | None = None
self._receiver_future: Future | None = None
def __enter__(self) -> Self:
# The thread pool is dedicated to running `_receive_loop`. Setting `max_workers` to 1
# ensures no unnecessary threads are created.
self._executor = ThreadPoolExecutor(max_workers=1)
self._receiver_future = self._executor.submit(self._receive_loop)
return self
def check_receiver_status(self):
"""`check_receiver_status` ensures that any exceptions raised during the
execution of `_receive_loop` are retrieved and propagated."""
if self._receiver_future and self._receiver_future.done():
self._receiver_future.result()
def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
):
self._read_stream.put(None)
self._write_stream.put(None)
# Wait for the receiver loop to finish
if self._receiver_future:
try:
self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds
except TimeoutError:
# If the receiver loop is still running after timeout, we'll force shutdown
# Cancel the future to interrupt the receiver loop
self._receiver_future.cancel()
# Shutdown the executor
if self._executor:
# Use non-blocking shutdown to prevent hanging
# The receiver thread should have already exited due to the None message in the queue
self._executor.shutdown(wait=False)
def send_request(
self,
request: SendRequestT,
result_type: type[ReceiveResultT],
request_read_timeout_seconds: timedelta | None = None,
metadata: MessageMetadata | None = None,
) -> ReceiveResultT:
"""
Sends a request and wait for a response. Raises an McpError if the
response contains an error. If a request read timeout is provided, it
will take precedence over the session read timeout.
Do not use this method to emit notifications! Use send_notification()
instead.
"""
self.check_receiver_status()
request_id = self._request_id
self._request_id = request_id + 1
response_queue: queue.Queue[JSONRPCResponse | JSONRPCError | HTTPStatusError] = queue.Queue()
self._response_streams[request_id] = response_queue
try:
jsonrpc_request = JSONRPCRequest(
jsonrpc="2.0",
id=request_id,
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
)
self._write_stream.put(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
timeout = DEFAULT_RESPONSE_READ_TIMEOUT
if request_read_timeout_seconds is not None:
timeout = float(request_read_timeout_seconds.total_seconds())
elif self._session_read_timeout_seconds is not None:
timeout = float(self._session_read_timeout_seconds.total_seconds())
while True:
try:
response_or_error = response_queue.get(timeout=timeout)
break
except queue.Empty:
self.check_receiver_status()
continue
if response_or_error is None:
raise MCPConnectionError(
ErrorData(
code=500,
message="No response received",
)
)
elif isinstance(response_or_error, HTTPStatusError):
# HTTPStatusError from streamable_client with preserved response object
if response_or_error.response.status_code == 401:
raise MCPAuthError(response=response_or_error.response)
else:
raise MCPConnectionError(
ErrorData(code=response_or_error.response.status_code, message=str(response_or_error))
)
elif isinstance(response_or_error, JSONRPCError):
if response_or_error.error.code == 401:
raise MCPAuthError(message=response_or_error.error.message)
else:
raise MCPConnectionError(
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
)
else:
return result_type.model_validate(response_or_error.result)
finally:
self._response_streams.pop(request_id, None)
def send_notification(
self,
notification: SendNotificationT,
related_request_id: RequestId | None = None,
):
"""
Emits a notification, which is a one-way message that does not expect
a response.
"""
self.check_receiver_status()
# Some transport implementations may need to set the related_request_id
# to attribute to the notifications to the request that triggered them.
jsonrpc_notification = JSONRPCNotification(
jsonrpc="2.0",
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
)
session_message = SessionMessage(
message=JSONRPCMessage(jsonrpc_notification),
metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None,
)
self._write_stream.put(session_message)
def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData):
if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
self._write_stream.put(session_message)
else:
jsonrpc_response = JSONRPCResponse(
jsonrpc="2.0",
id=request_id,
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
self._write_stream.put(session_message)
def _receive_loop(self):
"""
Main message processing loop.
In a real synchronous implementation, this would likely run in a separate thread.
"""
while True:
try:
# Attempt to receive a message (this would be blocking in a synchronous context)
message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
if message is None:
break
if isinstance(message, HTTPStatusError):
response_queue = self._response_streams.get(self._request_id - 1)
if response_queue is not None:
# For 401 errors, pass the HTTPStatusError directly to preserve response object
if message.response.status_code == 401:
response_queue.put(message)
else:
response_queue.put(
JSONRPCError(
jsonrpc="2.0",
id=self._request_id - 1,
error=ErrorData(code=message.response.status_code, message=message.args[0]),
)
)
else:
self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
elif isinstance(message, Exception):
self._handle_incoming(message)
elif isinstance(message.message.root, JSONRPCRequest):
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
responder = RequestResponder(
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
)
self._in_flight[responder.request_id] = responder
self._received_request(responder)
if not responder.completed:
self._handle_incoming(responder)
elif isinstance(message.message.root, JSONRPCNotification):
try:
notification = self._receive_notification_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
self._in_flight[cancelled_id].cancel()
else:
self._received_notification(notification)
self._handle_incoming(notification)
except Exception as e:
# For other validation errors, log and continue
logger.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root)
else: # Response or error
response_queue = self._response_streams.get(message.message.root.id)
if response_queue is not None:
response_queue.put(message.message.root)
else:
self._handle_incoming(RuntimeError(f"Server Error: {message}"))
except queue.Empty:
continue
except Exception:
logger.exception("Error in message processing loop")
raise
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]):
"""
Can be overridden by subclasses to handle a request without needing to
listen on the message stream.
If the request is responded to within this method, it will not be
forwarded on to the message stream.
"""
def _received_notification(self, notification: ReceiveNotificationT):
"""
Can be overridden by subclasses to handle a notification without needing
to listen on the message stream.
"""
def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None):
"""
Sends a progress notification for a request that is currently being
processed.
"""
def _handle_incoming(
self,
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
):
"""A generic handler for incoming messages. Overwritten by subclasses."""

View File

@@ -0,0 +1,368 @@
import queue
from datetime import timedelta
from typing import Any, Protocol
from pydantic import AnyUrl, TypeAdapter
from configs import dify_config
from core.mcp import types
from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext
from core.mcp.session.base_session import BaseSession, RequestResponder
DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.project.version)
class SamplingFnT(Protocol):
def __call__(
self,
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData: ...
class ListRootsFnT(Protocol):
def __call__(self, context: RequestContext["ClientSession", Any]) -> types.ListRootsResult | types.ErrorData: ...
class LoggingFnT(Protocol):
def __call__(
self,
params: types.LoggingMessageNotificationParams,
): ...
class MessageHandlerFnT(Protocol):
def __call__(
self,
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
): ...
def _default_message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
):
if isinstance(message, Exception):
raise ValueError(str(message))
elif isinstance(message, (types.ServerNotification | RequestResponder)):
pass
def _default_sampling_callback(
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="Sampling not supported",
)
def _default_list_roots_callback(
context: RequestContext["ClientSession", Any],
) -> types.ListRootsResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="List roots not supported",
)
def _default_logging_callback(
params: types.LoggingMessageNotificationParams,
):
pass
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
class ClientSession(
BaseSession[
types.ClientRequest,
types.ClientNotification,
types.ClientResult,
types.ServerRequest,
types.ServerNotification,
]
):
def __init__(
self,
read_stream: queue.Queue,
write_stream: queue.Queue,
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
client_info: types.Implementation | None = None,
):
super().__init__(
read_stream,
write_stream,
types.ServerRequest,
types.ServerNotification,
read_timeout_seconds=read_timeout_seconds,
)
self._client_info = client_info or DEFAULT_CLIENT_INFO
self._sampling_callback = sampling_callback or _default_sampling_callback
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
self._logging_callback = logging_callback or _default_logging_callback
self._message_handler = message_handler or _default_message_handler
def initialize(self) -> types.InitializeResult:
# Only set capabilities if non-default callbacks are provided
# This prevents servers from attempting callbacks when we don't actually support them
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
roots = (
types.RootsCapability(
# Only enable listChanged if we have a custom callback
listChanged=True,
)
if self._list_roots_callback is not _default_list_roots_callback
else None
)
result = self.send_request(
types.ClientRequest(
types.InitializeRequest(
method="initialize",
params=types.InitializeRequestParams(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=types.ClientCapabilities(
sampling=sampling,
experimental=None,
roots=roots,
),
clientInfo=self._client_info,
),
)
),
types.InitializeResult,
)
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}")
self.send_notification(
types.ClientNotification(types.InitializedNotification(method="notifications/initialized"))
)
return result
def send_ping(self) -> types.EmptyResult:
"""Send a ping request."""
return self.send_request(
types.ClientRequest(
types.PingRequest(
method="ping",
)
),
types.EmptyResult,
)
def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None):
"""Send a progress notification."""
self.send_notification(
types.ClientNotification(
types.ProgressNotification(
method="notifications/progress",
params=types.ProgressNotificationParams(
progressToken=progress_token,
progress=progress,
total=total,
),
),
)
)
def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
"""Send a logging/setLevel request."""
return self.send_request(
types.ClientRequest(
types.SetLevelRequest(
method="logging/setLevel",
params=types.SetLevelRequestParams(level=level),
)
),
types.EmptyResult,
)
def list_resources(self) -> types.ListResourcesResult:
"""Send a resources/list request."""
return self.send_request(
types.ClientRequest(
types.ListResourcesRequest(
method="resources/list",
)
),
types.ListResourcesResult,
)
def list_resource_templates(self) -> types.ListResourceTemplatesResult:
"""Send a resources/templates/list request."""
return self.send_request(
types.ClientRequest(
types.ListResourceTemplatesRequest(
method="resources/templates/list",
)
),
types.ListResourceTemplatesResult,
)
def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
"""Send a resources/read request."""
return self.send_request(
types.ClientRequest(
types.ReadResourceRequest(
method="resources/read",
params=types.ReadResourceRequestParams(uri=uri),
)
),
types.ReadResourceResult,
)
def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
"""Send a resources/subscribe request."""
return self.send_request(
types.ClientRequest(
types.SubscribeRequest(
method="resources/subscribe",
params=types.SubscribeRequestParams(uri=uri),
)
),
types.EmptyResult,
)
def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
"""Send a resources/unsubscribe request."""
return self.send_request(
types.ClientRequest(
types.UnsubscribeRequest(
method="resources/unsubscribe",
params=types.UnsubscribeRequestParams(uri=uri),
)
),
types.EmptyResult,
)
def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
) -> types.CallToolResult:
"""Send a tools/call request."""
return self.send_request(
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(name=name, arguments=arguments),
)
),
types.CallToolResult,
request_read_timeout_seconds=read_timeout_seconds,
)
def list_prompts(self) -> types.ListPromptsResult:
"""Send a prompts/list request."""
return self.send_request(
types.ClientRequest(
types.ListPromptsRequest(
method="prompts/list",
)
),
types.ListPromptsResult,
)
def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
"""Send a prompts/get request."""
return self.send_request(
types.ClientRequest(
types.GetPromptRequest(
method="prompts/get",
params=types.GetPromptRequestParams(name=name, arguments=arguments),
)
),
types.GetPromptResult,
)
def complete(
self,
ref: types.ResourceTemplateReference | types.PromptReference,
argument: dict[str, str],
) -> types.CompleteResult:
"""Send a completion/complete request."""
return self.send_request(
types.ClientRequest(
types.CompleteRequest(
method="completion/complete",
params=types.CompleteRequestParams(
ref=ref,
argument=types.CompletionArgument.model_validate(argument),
),
)
),
types.CompleteResult,
)
def list_tools(self) -> types.ListToolsResult:
"""Send a tools/list request."""
return self.send_request(
types.ClientRequest(
types.ListToolsRequest(
method="tools/list",
)
),
types.ListToolsResult,
)
def send_roots_list_changed(self):
"""Send a roots/list_changed notification."""
self.send_notification(
types.ClientNotification(
types.RootsListChangedNotification(
method="notifications/roots/list_changed",
)
)
)
def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]):
ctx = RequestContext[ClientSession, Any](
request_id=responder.request_id,
meta=responder.request_meta,
session=self,
lifespan_context=None,
)
match responder.request.root:
case types.CreateMessageRequest(params=params):
with responder:
response = self._sampling_callback(ctx, params)
client_response = ClientResponse.validate_python(response)
responder.respond(client_response)
case types.ListRootsRequest():
with responder:
list_roots_response = self._list_roots_callback(ctx)
client_response = ClientResponse.validate_python(list_roots_response)
responder.respond(client_response)
case types.PingRequest():
with responder:
return responder.respond(types.ClientResult(root=types.EmptyResult()))
def _handle_incoming(
self,
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
):
"""Handle incoming messages by forwarding to the message handler."""
self._message_handler(req)
def _received_notification(self, notification: types.ServerNotification):
"""Handle notifications from the server."""
# Process specific notification types
match notification.root:
case types.LoggingMessageNotification(params=params):
self._logging_callback(params)
case _:
pass

1342
dify/api/core/mcp/types.py Normal file

File diff suppressed because it is too large Load Diff

142
dify/api/core/mcp/utils.py Normal file
View File

@@ -0,0 +1,142 @@
import json
from collections.abc import Generator
from contextlib import AbstractContextManager
import httpx
import httpx_sse
from httpx_sse import connect_sse
from configs import dify_config
from core.mcp.types import ErrorData, JSONRPCError
from core.model_runtime.utils.encoders import jsonable_encoder
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
STATUS_FORCELIST = [429, 500, 502, 503, 504]
def create_ssrf_proxy_mcp_http_client(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
) -> httpx.Client:
"""Create an HTTPX client with SSRF proxy configuration for MCP connections.
Args:
headers: Optional headers to include in the client
timeout: Optional timeout configuration
Returns:
Configured httpx.Client with proxy settings
"""
if dify_config.SSRF_PROXY_ALL_URL:
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
proxy=dify_config.SSRF_PROXY_ALL_URL,
)
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY),
"https://": httpx.HTTPTransport(
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
),
}
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
mounts=proxy_mounts,
)
else:
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
)
def ssrf_proxy_sse_connect(url: str, **kwargs) -> AbstractContextManager[httpx_sse.EventSource]:
"""Connect to SSE endpoint with SSRF proxy protection.
This function creates an SSE connection using the configured proxy settings
to prevent SSRF attacks when connecting to external endpoints. It returns
a context manager that yields an EventSource object for SSE streaming.
The function handles HTTP client creation and cleanup automatically, but
also accepts a pre-configured client via kwargs.
Args:
url (str): The SSE endpoint URL to connect to
**kwargs: Additional arguments passed to the SSE connection, including:
- client (httpx.Client, optional): Pre-configured HTTP client.
If not provided, one will be created with SSRF protection.
- method (str, optional): HTTP method to use, defaults to "GET"
- headers (dict, optional): HTTP headers to include in the request
- timeout (httpx.Timeout, optional): Timeout configuration for the connection
Returns:
AbstractContextManager[httpx_sse.EventSource]: A context manager that yields an EventSource
object for SSE streaming. The EventSource provides access to server-sent events.
Example:
```python
with ssrf_proxy_sse_connect(url, headers=headers) as event_source:
for sse in event_source.iter_sse():
print(sse.event, sse.data)
```
Note:
If a client is not provided in kwargs, one will be automatically created
with SSRF protection based on the application's configuration. If an
exception occurs during connection, any automatically created client
will be cleaned up automatically.
"""
# Extract client if provided, otherwise create one
client = kwargs.pop("client", None)
if client is None:
# Create client with SSRF proxy configuration
timeout = kwargs.pop(
"timeout",
httpx.Timeout(
timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
),
)
headers = kwargs.pop("headers", {})
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
client_provided = False
else:
client_provided = True
# Extract method if provided, default to GET
method = kwargs.pop("method", "GET")
try:
return connect_sse(client, method, url, **kwargs)
except Exception:
# If we created the client, we need to clean it up on error
if not client_provided:
client.close()
raise
def create_mcp_error_response(
request_id: int | str | None, code: int, message: str, data=None
) -> Generator[bytes, None, None]:
"""Create MCP error response"""
error_data = ErrorData(code=code, message=message, data=data)
json_response = JSONRPCError(
jsonrpc="2.0",
id=request_id or 1,
error=error_data,
)
json_data = json.dumps(jsonable_encoder(json_response))
sse_content = json_data.encode()
yield sse_content