dify
This commit is contained in:
0
dify/api/services/auth/__init__.py
Normal file
0
dify/api/services/auth/__init__.py
Normal file
10
dify/api/services/auth/api_key_auth_base.py
Normal file
10
dify/api/services/auth/api_key_auth_base.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ApiKeyAuthBase(ABC):
|
||||
def __init__(self, credentials: dict):
|
||||
self.credentials = credentials
|
||||
|
||||
@abstractmethod
|
||||
def validate_credentials(self):
|
||||
raise NotImplementedError
|
||||
29
dify/api/services/auth/api_key_auth_factory.py
Normal file
29
dify/api/services/auth/api_key_auth_factory.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
from services.auth.auth_type import AuthType
|
||||
|
||||
|
||||
class ApiKeyAuthFactory:
|
||||
def __init__(self, provider: str, credentials: dict):
|
||||
auth_factory = self.get_apikey_auth_factory(provider)
|
||||
self.auth = auth_factory(credentials)
|
||||
|
||||
def validate_credentials(self):
|
||||
return self.auth.validate_credentials()
|
||||
|
||||
@staticmethod
|
||||
def get_apikey_auth_factory(provider: str) -> type[ApiKeyAuthBase]:
|
||||
match provider:
|
||||
case AuthType.FIRECRAWL:
|
||||
from services.auth.firecrawl.firecrawl import FirecrawlAuth
|
||||
|
||||
return FirecrawlAuth
|
||||
case AuthType.WATERCRAWL:
|
||||
from services.auth.watercrawl.watercrawl import WatercrawlAuth
|
||||
|
||||
return WatercrawlAuth
|
||||
case AuthType.JINA:
|
||||
from services.auth.jina.jina import JinaAuth
|
||||
|
||||
return JinaAuth
|
||||
case _:
|
||||
raise ValueError("Invalid provider")
|
||||
77
dify/api/services/auth/api_key_auth_service.py
Normal file
77
dify/api/services/auth/api_key_auth_service.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import json
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.helper import encrypter
|
||||
from extensions.ext_database import db
|
||||
from models.source import DataSourceApiKeyAuthBinding
|
||||
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
|
||||
|
||||
class ApiKeyAuthService:
|
||||
@staticmethod
|
||||
def get_provider_auth_list(tenant_id: str):
|
||||
data_source_api_key_bindings = db.session.scalars(
|
||||
select(DataSourceApiKeyAuthBinding).where(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)
|
||||
)
|
||||
).all()
|
||||
return data_source_api_key_bindings
|
||||
|
||||
@staticmethod
|
||||
def create_provider_auth(tenant_id: str, args: dict):
|
||||
auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
|
||||
if auth_result:
|
||||
# Encrypt the api key
|
||||
api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"])
|
||||
args["credentials"]["config"]["api_key"] = api_key
|
||||
|
||||
data_source_api_key_binding = DataSourceApiKeyAuthBinding(
|
||||
tenant_id=tenant_id, category=args["category"], provider=args["provider"]
|
||||
)
|
||||
data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
|
||||
db.session.add(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_auth_credentials(tenant_id: str, category: str, provider: str):
|
||||
data_source_api_key_bindings = (
|
||||
db.session.query(DataSourceApiKeyAuthBinding)
|
||||
.where(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.category == category,
|
||||
DataSourceApiKeyAuthBinding.provider == provider,
|
||||
DataSourceApiKeyAuthBinding.disabled.is_(False),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not data_source_api_key_bindings:
|
||||
return None
|
||||
if not data_source_api_key_bindings.credentials:
|
||||
return None
|
||||
credentials = json.loads(data_source_api_key_bindings.credentials)
|
||||
return credentials
|
||||
|
||||
@staticmethod
|
||||
def delete_provider_auth(tenant_id: str, binding_id: str):
|
||||
data_source_api_key_binding = (
|
||||
db.session.query(DataSourceApiKeyAuthBinding)
|
||||
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
|
||||
.first()
|
||||
)
|
||||
if data_source_api_key_binding:
|
||||
db.session.delete(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def validate_api_key_auth_args(cls, args):
|
||||
if "category" not in args or not args["category"]:
|
||||
raise ValueError("category is required")
|
||||
if "provider" not in args or not args["provider"]:
|
||||
raise ValueError("provider is required")
|
||||
if "credentials" not in args or not args["credentials"]:
|
||||
raise ValueError("credentials is required")
|
||||
if not isinstance(args["credentials"], dict):
|
||||
raise ValueError("credentials must be a dictionary")
|
||||
if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]:
|
||||
raise ValueError("auth_type is required")
|
||||
7
dify/api/services/auth/auth_type.py
Normal file
7
dify/api/services/auth/auth_type.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class AuthType(StrEnum):
|
||||
FIRECRAWL = "firecrawl"
|
||||
WATERCRAWL = "watercrawl"
|
||||
JINA = "jinareader"
|
||||
0
dify/api/services/auth/firecrawl/__init__.py
Normal file
0
dify/api/services/auth/firecrawl/__init__.py
Normal file
49
dify/api/services/auth/firecrawl/firecrawl.py
Normal file
49
dify/api/services/auth/firecrawl/firecrawl.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import json
|
||||
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
|
||||
|
||||
class FirecrawlAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "bearer":
|
||||
raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer")
|
||||
self.api_key = credentials.get("config", {}).get("api_key", None)
|
||||
self.base_url = credentials.get("config", {}).get("base_url", "https://api.firecrawl.dev")
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
def validate_credentials(self):
|
||||
headers = self._prepare_headers()
|
||||
options = {
|
||||
"url": "https://example.com",
|
||||
"includePaths": [],
|
||||
"excludePaths": [],
|
||||
"limit": 1,
|
||||
"scrapeOptions": {"onlyMainContent": True},
|
||||
}
|
||||
response = self._post_request(f"{self.base_url}/v1/crawl", options, headers)
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
self._handle_error(response)
|
||||
|
||||
def _prepare_headers(self):
|
||||
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def _post_request(self, url, data, headers):
|
||||
return httpx.post(url, headers=headers, json=data)
|
||||
|
||||
def _handle_error(self, response):
|
||||
if response.status_code in {402, 409, 500}:
|
||||
error_message = response.json().get("error", "Unknown error occurred")
|
||||
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
|
||||
else:
|
||||
if response.text:
|
||||
error_message = json.loads(response.text).get("error", "Unknown error occurred")
|
||||
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
|
||||
raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")
|
||||
44
dify/api/services/auth/jina.py
Normal file
44
dify/api/services/auth/jina.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import json
|
||||
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
|
||||
|
||||
class JinaAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "bearer":
|
||||
raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")
|
||||
self.api_key = credentials.get("config", {}).get("api_key", None)
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
def validate_credentials(self):
|
||||
headers = self._prepare_headers()
|
||||
options = {
|
||||
"url": "https://example.com",
|
||||
}
|
||||
response = self._post_request("https://r.jina.ai", options, headers)
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
self._handle_error(response)
|
||||
|
||||
def _prepare_headers(self):
|
||||
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def _post_request(self, url, data, headers):
|
||||
return httpx.post(url, headers=headers, json=data)
|
||||
|
||||
def _handle_error(self, response):
|
||||
if response.status_code in {402, 409, 500}:
|
||||
error_message = response.json().get("error", "Unknown error occurred")
|
||||
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
|
||||
else:
|
||||
if response.text:
|
||||
error_message = json.loads(response.text).get("error", "Unknown error occurred")
|
||||
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
|
||||
raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")
|
||||
0
dify/api/services/auth/jina/__init__.py
Normal file
0
dify/api/services/auth/jina/__init__.py
Normal file
44
dify/api/services/auth/jina/jina.py
Normal file
44
dify/api/services/auth/jina/jina.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import json
|
||||
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
|
||||
|
||||
class JinaAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "bearer":
|
||||
raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")
|
||||
self.api_key = credentials.get("config", {}).get("api_key", None)
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
def validate_credentials(self):
|
||||
headers = self._prepare_headers()
|
||||
options = {
|
||||
"url": "https://example.com",
|
||||
}
|
||||
response = self._post_request("https://r.jina.ai", options, headers)
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
self._handle_error(response)
|
||||
|
||||
def _prepare_headers(self):
|
||||
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def _post_request(self, url, data, headers):
|
||||
return httpx.post(url, headers=headers, json=data)
|
||||
|
||||
def _handle_error(self, response):
|
||||
if response.status_code in {402, 409, 500}:
|
||||
error_message = response.json().get("error", "Unknown error occurred")
|
||||
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
|
||||
else:
|
||||
if response.text:
|
||||
error_message = json.loads(response.text).get("error", "Unknown error occurred")
|
||||
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
|
||||
raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")
|
||||
0
dify/api/services/auth/watercrawl/__init__.py
Normal file
0
dify/api/services/auth/watercrawl/__init__.py
Normal file
44
dify/api/services/auth/watercrawl/watercrawl.py
Normal file
44
dify/api/services/auth/watercrawl/watercrawl.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import json
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
|
||||
|
||||
class WatercrawlAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "x-api-key":
|
||||
raise ValueError("Invalid auth type, WaterCrawl auth type must be x-api-key")
|
||||
self.api_key = credentials.get("config", {}).get("api_key", None)
|
||||
self.base_url = credentials.get("config", {}).get("base_url", "https://app.watercrawl.dev")
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
def validate_credentials(self):
|
||||
headers = self._prepare_headers()
|
||||
url = urljoin(self.base_url, "/api/v1/core/crawl-requests/")
|
||||
response = self._get_request(url, headers)
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
self._handle_error(response)
|
||||
|
||||
def _prepare_headers(self):
|
||||
return {"Content-Type": "application/json", "X-API-KEY": self.api_key}
|
||||
|
||||
def _get_request(self, url, headers):
|
||||
return httpx.get(url, headers=headers)
|
||||
|
||||
def _handle_error(self, response):
|
||||
if response.status_code in {402, 409, 500}:
|
||||
error_message = response.json().get("error", "Unknown error occurred")
|
||||
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
|
||||
else:
|
||||
if response.text:
|
||||
error_message = json.loads(response.text).get("error", "Unknown error occurred")
|
||||
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
|
||||
raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")
|
||||
Reference in New Issue
Block a user