dify
0
dify/api/extensions/__init__.py
Normal file
67
dify/api/extensions/ext_app_metrics.py
Normal file
@@ -0,0 +1,67 @@
import json
import os
import threading

from flask import Response

from configs import dify_config
from dify_app import DifyApp


def init_app(app: DifyApp):
    @app.after_request
    def after_request(response):  # pyright: ignore[reportUnusedFunction]
        """Add Version headers to the response."""
        response.headers.add("X-Version", dify_config.project.version)
        response.headers.add("X-Env", dify_config.DEPLOY_ENV)
        return response

    @app.route("/health")
    def health():  # pyright: ignore[reportUnusedFunction]
        return Response(
            json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.project.version}),
            status=200,
            content_type="application/json",
        )

    @app.route("/threads")
    def threads():  # pyright: ignore[reportUnusedFunction]
        num_threads = threading.active_count()
        threads = threading.enumerate()

        thread_list = []
        for thread in threads:
            thread_name = thread.name
            thread_id = thread.ident
            is_alive = thread.is_alive()

            thread_list.append(
                {
                    "name": thread_name,
                    "id": thread_id,
                    "is_alive": is_alive,
                }
            )

        return {
            "pid": os.getpid(),
            "thread_num": num_threads,
            "threads": thread_list,
        }

    @app.route("/db-pool-stat")
    def pool_stat():  # pyright: ignore[reportUnusedFunction]
        from extensions.ext_database import db

        engine = db.engine
        # TODO: Fix the type error
        # FIXME maybe its sqlalchemy issue
        return {
            "pid": os.getpid(),
            "pool_size": engine.pool.size(),  # type: ignore
            "checked_in_connections": engine.pool.checkedin(),  # type: ignore
            "checked_out_connections": engine.pool.checkedout(),  # type: ignore
            "overflow_connections": engine.pool.overflow(),  # type: ignore
            "connection_timeout": engine.pool.timeout(),  # type: ignore
            "recycle_time": db.engine.pool._recycle,  # type: ignore
        }
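A quick way to exercise these endpoints is Flask's test client. This is an illustrative sketch only: it assumes `app` is an already-initialized Dify Flask app with this extension registered (for example via the project's app factory).

    with app.test_client() as client:
        resp = client.get("/health")
        assert resp.status_code == 200
        print(resp.get_json())            # {"pid": ..., "status": "ok", "version": ...}
        print(resp.headers["X-Version"])  # added by the after_request hook above
        print(client.get("/threads").get_json()["thread_num"])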
67
dify/api/extensions/ext_blueprints.py
Normal file
@@ -0,0 +1,67 @@
from configs import dify_config
from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN, HEADER_NAME_PASSPORT
from dify_app import DifyApp

BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEADER_NAME_PASSPORT)
SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization")
AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN)
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)


def init_app(app: DifyApp):
    # register blueprint routers

    from flask_cors import CORS

    from controllers.console import bp as console_app_bp
    from controllers.files import bp as files_bp
    from controllers.inner_api import bp as inner_api_bp
    from controllers.mcp import bp as mcp_bp
    from controllers.service_api import bp as service_api_bp
    from controllers.trigger import bp as trigger_bp
    from controllers.web import bp as web_bp

    CORS(
        service_api_bp,
        allow_headers=list(SERVICE_API_HEADERS),
        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
    )
    app.register_blueprint(service_api_bp)

    CORS(
        web_bp,
        resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
        supports_credentials=True,
        allow_headers=list(AUTHENTICATED_HEADERS),
        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
        expose_headers=["X-Version", "X-Env"],
    )
    app.register_blueprint(web_bp)

    CORS(
        console_app_bp,
        resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
        supports_credentials=True,
        allow_headers=list(AUTHENTICATED_HEADERS),
        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
        expose_headers=["X-Version", "X-Env"],
    )
    app.register_blueprint(console_app_bp)

    CORS(
        files_bp,
        allow_headers=list(FILES_HEADERS),
        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
    )
    app.register_blueprint(files_bp)

    app.register_blueprint(inner_api_bp)
    app.register_blueprint(mcp_bp)

    # Register trigger blueprint with CORS for webhook calls
    CORS(
        trigger_bp,
        allow_headers=["Content-Type", "Authorization", "X-App-Code"],
        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH", "HEAD"],
    )
    app.register_blueprint(trigger_bp)
177
dify/api/extensions/ext_celery.py
Normal file
@@ -0,0 +1,177 @@
import ssl
from datetime import timedelta
from typing import Any

import pytz
from celery import Celery, Task
from celery.schedules import crontab

from configs import dify_config
from dify_app import DifyApp


def _get_celery_ssl_options() -> dict[str, Any] | None:
    """Get SSL configuration for Celery broker/backend connections."""
    # Use REDIS_USE_SSL for consistency with the main Redis client
    # Only apply SSL if we're using Redis as broker/backend
    if not dify_config.REDIS_USE_SSL:
        return None

    # Check if Celery is actually using Redis
    broker_is_redis = dify_config.CELERY_BROKER_URL and (
        dify_config.CELERY_BROKER_URL.startswith("redis://") or dify_config.CELERY_BROKER_URL.startswith("rediss://")
    )

    if not broker_is_redis:
        return None

    # Map certificate requirement strings to SSL constants
    cert_reqs_map = {
        "CERT_NONE": ssl.CERT_NONE,
        "CERT_OPTIONAL": ssl.CERT_OPTIONAL,
        "CERT_REQUIRED": ssl.CERT_REQUIRED,
    }
    ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE)

    ssl_options = {
        "ssl_cert_reqs": ssl_cert_reqs,
        "ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS,
        "ssl_certfile": dify_config.REDIS_SSL_CERTFILE,
        "ssl_keyfile": dify_config.REDIS_SSL_KEYFILE,
    }

    return ssl_options


def init_app(app: DifyApp) -> Celery:
    class FlaskTask(Task):
        def __call__(self, *args: object, **kwargs: object) -> object:
            with app.app_context():
                return self.run(*args, **kwargs)

    broker_transport_options = {}

    if dify_config.CELERY_USE_SENTINEL:
        broker_transport_options = {
            "master_name": dify_config.CELERY_SENTINEL_MASTER_NAME,
            "sentinel_kwargs": {
                "socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
                "password": dify_config.CELERY_SENTINEL_PASSWORD,
            },
        }

    celery_app = Celery(
        app.name,
        task_cls=FlaskTask,
        broker=dify_config.CELERY_BROKER_URL,
        backend=dify_config.CELERY_BACKEND,
    )

    celery_app.conf.update(
        result_backend=dify_config.CELERY_RESULT_BACKEND,
        broker_transport_options=broker_transport_options,
        broker_connection_retry_on_startup=True,
        worker_log_format=dify_config.LOG_FORMAT,
        worker_task_log_format=dify_config.LOG_FORMAT,
        worker_hijack_root_logger=False,
        timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"),
        task_ignore_result=True,
    )

    # Apply SSL configuration if enabled
    ssl_options = _get_celery_ssl_options()
    if ssl_options:
        celery_app.conf.update(
            broker_use_ssl=ssl_options,
            # Also apply SSL to the backend if it's Redis
            redis_backend_use_ssl=ssl_options if dify_config.CELERY_BACKEND == "redis" else None,
        )

    if dify_config.LOG_FILE:
        celery_app.conf.update(
            worker_logfile=dify_config.LOG_FILE,
        )

    celery_app.set_default()
    app.extensions["celery"] = celery_app

    imports = [
        "tasks.async_workflow_tasks",  # trigger workers
        "tasks.trigger_processing_tasks",  # async trigger processing
    ]
    day = dify_config.CELERY_BEAT_SCHEDULER_TIME

    # if you add a new task, please add the switch to CeleryScheduleTasksConfig
    beat_schedule = {}
    if dify_config.ENABLE_CLEAN_EMBEDDING_CACHE_TASK:
        imports.append("schedule.clean_embedding_cache_task")
        beat_schedule["clean_embedding_cache_task"] = {
            "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task",
            "schedule": crontab(minute="0", hour="2", day_of_month=f"*/{day}"),
        }
    if dify_config.ENABLE_CLEAN_UNUSED_DATASETS_TASK:
        imports.append("schedule.clean_unused_datasets_task")
        beat_schedule["clean_unused_datasets_task"] = {
            "task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task",
            "schedule": crontab(minute="0", hour="3", day_of_month=f"*/{day}"),
        }
    if dify_config.ENABLE_CREATE_TIDB_SERVERLESS_TASK:
        imports.append("schedule.create_tidb_serverless_task")
        beat_schedule["create_tidb_serverless_task"] = {
            "task": "schedule.create_tidb_serverless_task.create_tidb_serverless_task",
            "schedule": crontab(minute="0", hour="*"),
        }
    if dify_config.ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK:
        imports.append("schedule.update_tidb_serverless_status_task")
        beat_schedule["update_tidb_serverless_status_task"] = {
            "task": "schedule.update_tidb_serverless_status_task.update_tidb_serverless_status_task",
            "schedule": timedelta(minutes=10),
        }
    if dify_config.ENABLE_CLEAN_MESSAGES:
        imports.append("schedule.clean_messages")
        beat_schedule["clean_messages"] = {
            "task": "schedule.clean_messages.clean_messages",
            "schedule": crontab(minute="0", hour="4", day_of_month=f"*/{day}"),
        }
    if dify_config.ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK:
        imports.append("schedule.mail_clean_document_notify_task")
        beat_schedule["mail_clean_document_notify_task"] = {
            "task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task",
            "schedule": crontab(minute="0", hour="10", day_of_week="1"),
        }
    if dify_config.ENABLE_DATASETS_QUEUE_MONITOR:
        imports.append("schedule.queue_monitor_task")
        beat_schedule["datasets-queue-monitor"] = {
            "task": "schedule.queue_monitor_task.queue_monitor_task",
            "schedule": timedelta(minutes=dify_config.QUEUE_MONITOR_INTERVAL or 30),
        }
    if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED:
        imports.append("schedule.check_upgradable_plugin_task")
        imports.append("tasks.process_tenant_plugin_autoupgrade_check_task")
        beat_schedule["check_upgradable_plugin_task"] = {
            "task": "schedule.check_upgradable_plugin_task.check_upgradable_plugin_task",
            "schedule": crontab(minute="*/15"),
        }
    if dify_config.WORKFLOW_LOG_CLEANUP_ENABLED:
        # 2:00 AM every day
        imports.append("schedule.clean_workflow_runlogs_precise")
        beat_schedule["clean_workflow_runlogs_precise"] = {
            "task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise",
            "schedule": crontab(minute="0", hour="2"),
        }
    if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
        imports.append("schedule.workflow_schedule_task")
        beat_schedule["workflow_schedule_task"] = {
            "task": "schedule.workflow_schedule_task.poll_workflow_schedules",
            "schedule": timedelta(minutes=dify_config.WORKFLOW_SCHEDULE_POLLER_INTERVAL),
        }
    if dify_config.ENABLE_TRIGGER_PROVIDER_REFRESH_TASK:
        imports.append("schedule.trigger_provider_refresh_task")
        beat_schedule["trigger_provider_refresh"] = {
            "task": "schedule.trigger_provider_refresh_task.trigger_provider_refresh",
            "schedule": timedelta(minutes=dify_config.TRIGGER_PROVIDER_REFRESH_INTERVAL),
        }
    celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)

    return celery_app
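The comment in the code asks that any new task be gated behind its own switch in CeleryScheduleTasksConfig. A hypothetical sketch of that pattern (the flag, module, and schedule names below are made up for illustration); the block would sit with the others, before the final conf.update call:

    if dify_config.ENABLE_MY_CLEANUP_TASK:  # hypothetical config switch
        imports.append("schedule.my_cleanup_task")  # hypothetical schedule module
        beat_schedule["my_cleanup_task"] = {
            "task": "schedule.my_cleanup_task.my_cleanup_task",
            "schedule": crontab(minute="30", hour="1"),
        }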
9
dify/api/extensions/ext_code_based_extension.py
Normal file
@@ -0,0 +1,9 @@
from core.extension.extension import Extension
from dify_app import DifyApp


def init_app(app: DifyApp):
    code_based_extension.init()


code_based_extension = Extension()
59
dify/api/extensions/ext_commands.py
Normal file
@@ -0,0 +1,59 @@
from dify_app import DifyApp


def init_app(app: DifyApp):
    from commands import (
        add_qdrant_index,
        cleanup_orphaned_draft_variables,
        clear_free_plan_tenant_expired_logs,
        clear_orphaned_file_records,
        convert_to_agent_apps,
        create_tenant,
        extract_plugins,
        extract_unique_plugins,
        fix_app_site_missing,
        install_plugins,
        install_rag_pipeline_plugins,
        migrate_data_for_plugin,
        migrate_oss,
        old_metadata_migration,
        remove_orphaned_files_on_storage,
        reset_email,
        reset_encrypt_key_pair,
        reset_password,
        setup_datasource_oauth_client,
        setup_system_tool_oauth_client,
        setup_system_trigger_oauth_client,
        transform_datasource_credentials,
        upgrade_db,
        vdb_migrate,
    )

    cmds_to_register = [
        reset_password,
        reset_email,
        reset_encrypt_key_pair,
        vdb_migrate,
        convert_to_agent_apps,
        add_qdrant_index,
        create_tenant,
        upgrade_db,
        fix_app_site_missing,
        migrate_data_for_plugin,
        extract_plugins,
        extract_unique_plugins,
        install_plugins,
        old_metadata_migration,
        clear_free_plan_tenant_expired_logs,
        clear_orphaned_file_records,
        remove_orphaned_files_on_storage,
        setup_system_tool_oauth_client,
        setup_system_trigger_oauth_client,
        cleanup_orphaned_draft_variables,
        migrate_oss,
        setup_datasource_oauth_client,
        transform_datasource_credentials,
        install_rag_pipeline_plugins,
    ]
    for cmd in cmds_to_register:
        app.cli.add_command(cmd)
13
dify/api/extensions/ext_compress.py
Normal file
@@ -0,0 +1,13 @@
from configs import dify_config
from dify_app import DifyApp


def is_enabled() -> bool:
    return dify_config.API_COMPRESSION_ENABLED


def init_app(app: DifyApp):
    from flask_compress import Compress

    compress = Compress()
    compress.init_app(app)
55
dify/api/extensions/ext_database.py
Normal file
@@ -0,0 +1,55 @@
import logging

import gevent
from sqlalchemy import event
from sqlalchemy.pool import Pool

from dify_app import DifyApp
from models.engine import db

logger = logging.getLogger(__name__)

# Global flag to avoid duplicate registration of event listener
_gevent_compatibility_setup: bool = False


def _safe_rollback(connection):
    """Safely rollback database connection.

    Args:
        connection: Database connection object
    """
    try:
        connection.rollback()
    except Exception:  # pylint: disable=broad-exception-caught
        logger.exception("Failed to rollback connection")


def _setup_gevent_compatibility():
    global _gevent_compatibility_setup  # pylint: disable=global-statement

    # Avoid duplicate registration
    if _gevent_compatibility_setup:
        return

    @event.listens_for(Pool, "reset")
    def _safe_reset(dbapi_connection, connection_record, reset_state):  # pyright: ignore[reportUnusedFunction]
        if reset_state.terminate_only:
            return

        # Safe rollback for connection
        try:
            hub = gevent.get_hub()
            if hasattr(hub, "loop") and getattr(hub.loop, "in_callback", False):
                gevent.spawn_later(0, lambda: _safe_rollback(dbapi_connection))
            else:
                _safe_rollback(dbapi_connection)
        except (AttributeError, ImportError):
            _safe_rollback(dbapi_connection)

    _gevent_compatibility_setup = True


def init_app(app: DifyApp):
    db.init_app(app)
    _setup_gevent_compatibility()
10
dify/api/extensions/ext_hosting_provider.py
Normal file
@@ -0,0 +1,10 @@
from core.hosting_configuration import HostingConfiguration

hosting_configuration = HostingConfiguration()


from dify_app import DifyApp


def init_app(app: DifyApp):
    hosting_configuration.init_app(app)
5
dify/api/extensions/ext_import_modules.py
Normal file
@@ -0,0 +1,5 @@
from dify_app import DifyApp


def init_app(app: DifyApp):
    from events import event_handlers  # noqa: F401 # pyright: ignore[reportUnusedImport]
93
dify/api/extensions/ext_logging.py
Normal file
@@ -0,0 +1,93 @@
import logging
import os
import sys
import uuid
from logging.handlers import RotatingFileHandler

import flask

from configs import dify_config
from dify_app import DifyApp


def init_app(app: DifyApp):
    log_handlers: list[logging.Handler] = []
    log_file = dify_config.LOG_FILE
    if log_file:
        log_dir = os.path.dirname(log_file)
        os.makedirs(log_dir, exist_ok=True)
        log_handlers.append(
            RotatingFileHandler(
                filename=log_file,
                maxBytes=dify_config.LOG_FILE_MAX_SIZE * 1024 * 1024,
                backupCount=dify_config.LOG_FILE_BACKUP_COUNT,
            )
        )

    # Always add StreamHandler to log to console
    sh = logging.StreamHandler(sys.stdout)
    log_handlers.append(sh)

    # Apply RequestIdFilter to all handlers
    for handler in log_handlers:
        handler.addFilter(RequestIdFilter())

    logging.basicConfig(
        level=dify_config.LOG_LEVEL,
        format=dify_config.LOG_FORMAT,
        datefmt=dify_config.LOG_DATEFORMAT,
        handlers=log_handlers,
        force=True,
    )

    # Apply RequestIdFormatter to all handlers
    apply_request_id_formatter()

    # Disable propagation for noisy loggers to avoid duplicate logs
    logging.getLogger("sqlalchemy.engine").propagate = False
    log_tz = dify_config.LOG_TZ
    if log_tz:
        from datetime import datetime

        import pytz

        timezone = pytz.timezone(log_tz)

        def time_converter(seconds):
            return datetime.fromtimestamp(seconds, tz=timezone).timetuple()

        for handler in logging.root.handlers:
            if handler.formatter:
                handler.formatter.converter = time_converter


def get_request_id():
    if getattr(flask.g, "request_id", None):
        return flask.g.request_id

    new_uuid = uuid.uuid4().hex[:10]
    flask.g.request_id = new_uuid

    return new_uuid


class RequestIdFilter(logging.Filter):
    # This is a logging filter that makes the request ID available for use in
    # the logging format. Note that we're checking if we're in a request
    # context, as we may want to log things before Flask is fully loaded.
    def filter(self, record):
        record.req_id = get_request_id() if flask.has_request_context() else ""
        return True


class RequestIdFormatter(logging.Formatter):
    def format(self, record):
        if not hasattr(record, "req_id"):
            record.req_id = ""
        return super().format(record)


def apply_request_id_formatter():
    for handler in logging.root.handlers:
        if handler.formatter:
            handler.formatter = RequestIdFormatter(dify_config.LOG_FORMAT, dify_config.LOG_DATEFORMAT)
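RequestIdFilter exposes the request ID on each record as req_id, so it only appears in output if the configured format references it. A minimal sketch, assuming (not shown in this commit) that dify_config.LOG_FORMAT contains a %(req_id)s placeholder:

    import logging

    logger = logging.getLogger(__name__)
    # Assumption for illustration: LOG_FORMAT is something like
    # "%(asctime)s %(levelname)s [%(req_id)s] %(name)s: %(message)s".
    # Inside a request context this line is tagged with the 10-character
    # request id; outside a request context req_id renders as an empty string.
    logger.info("handled request")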
127
dify/api/extensions/ext_login.py
Normal file
@@ -0,0 +1,127 @@
import json

import flask_login
from flask import Response, request
from flask_login import user_loaded_from_request, user_logged_in
from werkzeug.exceptions import NotFound, Unauthorized

from configs import dify_config
from constants import HEADER_NAME_APP_CODE
from dify_app import DifyApp
from extensions.ext_database import db
from libs.passport import PassportService
from libs.token import extract_access_token, extract_webapp_passport
from models import Account, Tenant, TenantAccountJoin
from models.model import AppMCPServer, EndUser
from services.account_service import AccountService

login_manager = flask_login.LoginManager()


# Flask-Login configuration
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
    """Load user based on the request."""
    # Skip authentication for documentation endpoints
    if dify_config.SWAGGER_UI_ENABLED and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")):
        return None

    auth_token = extract_access_token(request)

    # Check for admin API key authentication first
    if dify_config.ADMIN_API_KEY_ENABLE and auth_token:
        admin_api_key = dify_config.ADMIN_API_KEY
        if admin_api_key and admin_api_key == auth_token:
            workspace_id = request.headers.get("X-WORKSPACE-ID")
            if workspace_id:
                tenant_account_join = (
                    db.session.query(Tenant, TenantAccountJoin)
                    .where(Tenant.id == workspace_id)
                    .where(TenantAccountJoin.tenant_id == Tenant.id)
                    .where(TenantAccountJoin.role == "owner")
                    .one_or_none()
                )
                if tenant_account_join:
                    tenant, ta = tenant_account_join
                    account = db.session.query(Account).filter_by(id=ta.account_id).first()
                    if account:
                        account.current_tenant = tenant
                        return account

    if request.blueprint in {"console", "inner_api"}:
        if not auth_token:
            raise Unauthorized("Invalid Authorization token.")
        decoded = PassportService().verify(auth_token)
        user_id = decoded.get("user_id")
        source = decoded.get("token_source")
        if source:
            raise Unauthorized("Invalid Authorization token.")
        if not user_id:
            raise Unauthorized("Invalid Authorization token.")

        logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
        return logged_in_account
    elif request.blueprint == "web":
        app_code = request.headers.get(HEADER_NAME_APP_CODE)
        webapp_token = extract_webapp_passport(app_code, request) if app_code else None

        if webapp_token:
            decoded = PassportService().verify(webapp_token)
            end_user_id = decoded.get("end_user_id")
            if not end_user_id:
                raise Unauthorized("Invalid Authorization token.")
            end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first()
            if not end_user:
                raise NotFound("End user not found.")
            return end_user
        else:
            if not auth_token:
                raise Unauthorized("Invalid Authorization token.")
            decoded = PassportService().verify(auth_token)
            end_user_id = decoded.get("end_user_id")
            if end_user_id:
                end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first()
                if not end_user:
                    raise NotFound("End user not found.")
                return end_user
            else:
                raise Unauthorized("Invalid Authorization token for web API.")
    elif request.blueprint == "mcp":
        server_code = request.view_args.get("server_code") if request.view_args else None
        if not server_code:
            raise Unauthorized("Invalid Authorization token.")
        app_mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
        if not app_mcp_server:
            raise NotFound("App MCP server not found.")
        end_user = (
            db.session.query(EndUser).where(EndUser.session_id == app_mcp_server.id, EndUser.type == "mcp").first()
        )
        if not end_user:
            raise NotFound("End user not found.")
        return end_user


@user_logged_in.connect
@user_loaded_from_request.connect
def on_user_logged_in(_sender, user):
    """Called when a user logs in.

    Note: AccountService.load_logged_in_account will populate user.current_tenant_id
    through the load_user method, which calls account.set_tenant_id().
    """
    # tenant_id context variable removed - using current_user.current_tenant_id directly
    pass


@login_manager.unauthorized_handler
def unauthorized_handler():
    """Handle unauthorized requests."""
    return Response(
        json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
        status=401,
        content_type="application/json",
    )


def init_app(app: DifyApp):
    login_manager.init_app(app)
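The admin API key branch above can be exercised roughly as follows. This is a hedged sketch: it assumes extract_access_token reads a standard "Authorization: Bearer <token>" header, and the route shown is only a placeholder for any console endpoint.

    with app.test_client() as client:
        resp = client.get(
            "/console/api/workspaces",  # placeholder console endpoint
            headers={
                "Authorization": f"Bearer {dify_config.ADMIN_API_KEY}",
                # selects the tenant whose owner account gets loaded
                "X-WORKSPACE-ID": "<workspace-uuid>",
            },
        )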
107
dify/api/extensions/ext_mail.py
Normal file
@@ -0,0 +1,107 @@
import logging

from flask import Flask

from configs import dify_config
from dify_app import DifyApp

logger = logging.getLogger(__name__)


class Mail:
    def __init__(self):
        self._client = None
        self._default_send_from = None

    def is_inited(self) -> bool:
        return self._client is not None

    def init_app(self, app: Flask):
        mail_type = dify_config.MAIL_TYPE
        if not mail_type:
            logger.warning("MAIL_TYPE is not set")
            return

        if dify_config.MAIL_DEFAULT_SEND_FROM:
            self._default_send_from = dify_config.MAIL_DEFAULT_SEND_FROM

        match mail_type:
            case "resend":
                import resend

                api_key = dify_config.RESEND_API_KEY
                if not api_key:
                    raise ValueError("RESEND_API_KEY is not set")

                api_url = dify_config.RESEND_API_URL
                if api_url:
                    resend.api_url = api_url

                resend.api_key = api_key
                self._client = resend.Emails
            case "smtp":
                from libs.smtp import SMTPClient

                if not dify_config.SMTP_SERVER or not dify_config.SMTP_PORT:
                    raise ValueError("SMTP_SERVER and SMTP_PORT are required for smtp mail type")
                if not dify_config.SMTP_USE_TLS and dify_config.SMTP_OPPORTUNISTIC_TLS:
                    raise ValueError("SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS")
                self._client = SMTPClient(
                    server=dify_config.SMTP_SERVER,
                    port=dify_config.SMTP_PORT,
                    username=dify_config.SMTP_USERNAME or "",
                    password=dify_config.SMTP_PASSWORD or "",
                    _from=dify_config.MAIL_DEFAULT_SEND_FROM or "",
                    use_tls=dify_config.SMTP_USE_TLS,
                    opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS,
                )
            case "sendgrid":
                from libs.sendgrid import SendGridClient

                if not dify_config.SENDGRID_API_KEY:
                    raise ValueError("SENDGRID_API_KEY is required for SendGrid mail type")

                self._client = SendGridClient(
                    sendgrid_api_key=dify_config.SENDGRID_API_KEY, _from=dify_config.MAIL_DEFAULT_SEND_FROM or ""
                )
            case _:
                raise ValueError(f"Unsupported mail type {mail_type}")

    def send(self, to: str, subject: str, html: str, from_: str | None = None):
        if not self._client:
            raise ValueError("Mail client is not initialized")

        if not from_ and self._default_send_from:
            from_ = self._default_send_from

        if not from_:
            raise ValueError("mail from is not set")

        if not to:
            raise ValueError("mail to is not set")

        if not subject:
            raise ValueError("mail subject is not set")

        if not html:
            raise ValueError("mail html is not set")

        self._client.send(
            {
                "from": from_,
                "to": to,
                "subject": subject,
                "html": html,
            }
        )


def is_enabled() -> bool:
    return dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""


def init_app(app: DifyApp):
    mail.init_app(app)


mail = Mail()
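A minimal usage sketch for the module-level mail instance, assuming MAIL_TYPE and the matching provider credentials are configured so that init_app has set up a client:

    from extensions.ext_mail import mail

    if mail.is_inited():
        mail.send(
            to="user@example.com",
            subject="Welcome to Dify",
            html="<p>Hello!</p>",
            # from_ is omitted and falls back to MAIL_DEFAULT_SEND_FROM
        )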
9
dify/api/extensions/ext_migrate.py
Normal file
@@ -0,0 +1,9 @@
from dify_app import DifyApp


def init_app(app: DifyApp):
    import flask_migrate

    from extensions.ext_database import db

    flask_migrate.Migrate(app, db)
8
dify/api/extensions/ext_orjson.py
Normal file
@@ -0,0 +1,8 @@
from flask_orjson import OrjsonProvider

from dify_app import DifyApp


def init_app(app: DifyApp):
    """Initialize Flask-Orjson extension for faster JSON serialization"""
    app.json = OrjsonProvider(app)
259
dify/api/extensions/ext_otel.py
Normal file
@@ -0,0 +1,259 @@
import atexit
import contextlib
import logging
import os
import platform
import socket
import sys
from typing import Union

import flask
from celery.signals import worker_init
from flask_login import user_loaded_from_request, user_logged_in

from configs import dify_config
from dify_app import DifyApp
from libs.helper import extract_tenant_id
from models import Account, EndUser

logger = logging.getLogger(__name__)


@user_logged_in.connect
@user_loaded_from_request.connect
def on_user_loaded(_sender, user: Union["Account", "EndUser"]):
    if dify_config.ENABLE_OTEL:
        from opentelemetry.trace import get_current_span

        if user:
            try:
                current_span = get_current_span()
                tenant_id = extract_tenant_id(user)
                if not tenant_id:
                    return
                if current_span:
                    current_span.set_attribute("service.tenant.id", tenant_id)
                    current_span.set_attribute("service.user.id", user.id)
            except Exception:
                logger.exception("Error setting tenant and user attributes")
                pass


def init_app(app: DifyApp):
    from opentelemetry.semconv.trace import SpanAttributes

    def is_celery_worker():
        return "celery" in sys.argv[0].lower()

    def instrument_exception_logging():
        exception_handler = ExceptionLoggingHandler()
        logging.getLogger().addHandler(exception_handler)

    def init_flask_instrumentor(app: DifyApp):
        meter = get_meter("http_metrics", version=dify_config.project.version)
        _http_response_counter = meter.create_counter(
            "http.server.response.count",
            description="Total number of HTTP responses by status code, method and target",
            unit="{response}",
        )

        def response_hook(span: Span, status: str, response_headers: list):
            if span and span.is_recording():
                try:
                    if status.startswith("2"):
                        span.set_status(StatusCode.OK)
                    else:
                        span.set_status(StatusCode.ERROR, status)

                    status = status.split(" ")[0]
                    status_code = int(status)
                    status_class = f"{status_code // 100}xx"
                    attributes: dict[str, str | int] = {"status_code": status_code, "status_class": status_class}
                    request = flask.request
                    if request and request.url_rule:
                        attributes[SpanAttributes.HTTP_TARGET] = str(request.url_rule.rule)
                    if request and request.method:
                        attributes[SpanAttributes.HTTP_METHOD] = str(request.method)
                    _http_response_counter.add(1, attributes)
                except Exception:
                    logger.exception("Error setting status and attributes")
                    pass

        instrumentor = FlaskInstrumentor()
        if dify_config.DEBUG:
            logger.info("Initializing Flask instrumentor")
        instrumentor.instrument_app(app, response_hook=response_hook)

    def init_sqlalchemy_instrumentor(app: DifyApp):
        with app.app_context():
            engines = list(app.extensions["sqlalchemy"].engines.values())
            SQLAlchemyInstrumentor().instrument(enable_commenter=True, engines=engines)

    def setup_context_propagation():
        # Configure propagators
        set_global_textmap(
            CompositePropagator(
                [
                    TraceContextTextMapPropagator(),  # W3C trace context
                    B3Format(),  # B3 propagation (used by many systems)
                ]
            )
        )

    def shutdown_tracer():
        provider = trace.get_tracer_provider()
        if hasattr(provider, "force_flush"):
            provider.force_flush()

    class ExceptionLoggingHandler(logging.Handler):
        """Custom logging handler that creates spans for logging.exception() calls"""

        def emit(self, record: logging.LogRecord):
            with contextlib.suppress(Exception):
                if record.exc_info:
                    tracer = get_tracer_provider().get_tracer("dify.exception.logging")
                    with tracer.start_as_current_span(
                        "log.exception",
                        attributes={
                            "log.level": record.levelname,
                            "log.message": record.getMessage(),
                            "log.logger": record.name,
                            "log.file.path": record.pathname,
                            "log.file.line": record.lineno,
                        },
                    ) as span:
                        span.set_status(StatusCode.ERROR)
                        if record.exc_info[1]:
                            span.record_exception(record.exc_info[1])
                            span.set_attribute("exception.message", str(record.exc_info[1]))
                        if record.exc_info[0]:
                            span.set_attribute("exception.type", record.exc_info[0].__name__)

    from opentelemetry import trace
    from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter
    from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter
    from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as HTTPMetricExporter
    from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter
    from opentelemetry.instrumentation.celery import CeleryInstrumentor
    from opentelemetry.instrumentation.flask import FlaskInstrumentor
    from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
    from opentelemetry.instrumentation.redis import RedisInstrumentor
    from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
    from opentelemetry.metrics import get_meter, get_meter_provider, set_meter_provider
    from opentelemetry.propagate import set_global_textmap
    from opentelemetry.propagators.b3 import B3Format
    from opentelemetry.propagators.composite import CompositePropagator
    from opentelemetry.sdk.metrics import MeterProvider
    from opentelemetry.sdk.metrics.export import ConsoleMetricExporter, PeriodicExportingMetricReader
    from opentelemetry.sdk.resources import Resource
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import (
        BatchSpanProcessor,
        ConsoleSpanExporter,
    )
    from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio
    from opentelemetry.semconv.resource import ResourceAttributes
    from opentelemetry.trace import Span, get_tracer_provider, set_tracer_provider
    from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
    from opentelemetry.trace.status import StatusCode

    setup_context_propagation()
    # Initialize OpenTelemetry
    # Follow Semantic Conventions 1.32.0 to define resource attributes
    resource = Resource(
        attributes={
            ResourceAttributes.SERVICE_NAME: dify_config.APPLICATION_NAME,
            ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
            ResourceAttributes.PROCESS_PID: os.getpid(),
            ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
            ResourceAttributes.HOST_NAME: socket.gethostname(),
            ResourceAttributes.HOST_ARCH: platform.machine(),
            "custom.deployment.git_commit": dify_config.COMMIT_SHA,
            ResourceAttributes.HOST_ID: platform.node(),
            ResourceAttributes.OS_TYPE: platform.system().lower(),
            ResourceAttributes.OS_DESCRIPTION: platform.platform(),
            ResourceAttributes.OS_VERSION: platform.version(),
        }
    )
    sampler = ParentBasedTraceIdRatio(dify_config.OTEL_SAMPLING_RATE)
    provider = TracerProvider(resource=resource, sampler=sampler)
    set_tracer_provider(provider)
    exporter: Union[GRPCSpanExporter, HTTPSpanExporter, ConsoleSpanExporter]
    metric_exporter: Union[GRPCMetricExporter, HTTPMetricExporter, ConsoleMetricExporter]
    protocol = (dify_config.OTEL_EXPORTER_OTLP_PROTOCOL or "").lower()
    if dify_config.OTEL_EXPORTER_TYPE == "otlp":
        if protocol == "grpc":
            exporter = GRPCSpanExporter(
                endpoint=dify_config.OTLP_BASE_ENDPOINT,
                # Header field names must consist of lowercase letters, check RFC7540
                headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),),
                insecure=True,
            )
            metric_exporter = GRPCMetricExporter(
                endpoint=dify_config.OTLP_BASE_ENDPOINT,
                headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),),
                insecure=True,
            )
        else:
            headers = {"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"} if dify_config.OTLP_API_KEY else None

            trace_endpoint = dify_config.OTLP_TRACE_ENDPOINT
            if not trace_endpoint:
                trace_endpoint = dify_config.OTLP_BASE_ENDPOINT + "/v1/traces"
            exporter = HTTPSpanExporter(
                endpoint=trace_endpoint,
                headers=headers,
            )

            metric_endpoint = dify_config.OTLP_METRIC_ENDPOINT
            if not metric_endpoint:
                metric_endpoint = dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics"
            metric_exporter = HTTPMetricExporter(
                endpoint=metric_endpoint,
                headers=headers,
            )
    else:
        exporter = ConsoleSpanExporter()
        metric_exporter = ConsoleMetricExporter()

    provider.add_span_processor(
        BatchSpanProcessor(
            exporter,
            max_queue_size=dify_config.OTEL_MAX_QUEUE_SIZE,
            schedule_delay_millis=dify_config.OTEL_BATCH_EXPORT_SCHEDULE_DELAY,
            max_export_batch_size=dify_config.OTEL_MAX_EXPORT_BATCH_SIZE,
            export_timeout_millis=dify_config.OTEL_BATCH_EXPORT_TIMEOUT,
        )
    )
    reader = PeriodicExportingMetricReader(
        metric_exporter,
        export_interval_millis=dify_config.OTEL_METRIC_EXPORT_INTERVAL,
        export_timeout_millis=dify_config.OTEL_METRIC_EXPORT_TIMEOUT,
    )
    set_meter_provider(MeterProvider(resource=resource, metric_readers=[reader]))
    if not is_celery_worker():
        init_flask_instrumentor(app)
        CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()).instrument()
    instrument_exception_logging()
    init_sqlalchemy_instrumentor(app)
    RedisInstrumentor().instrument()
    HTTPXClientInstrumentor().instrument()
    atexit.register(shutdown_tracer)


def is_enabled():
    return dify_config.ENABLE_OTEL


@worker_init.connect(weak=False)
def init_celery_worker(*args, **kwargs):
    if dify_config.ENABLE_OTEL:
        from opentelemetry.instrumentation.celery import CeleryInstrumentor
        from opentelemetry.metrics import get_meter_provider
        from opentelemetry.trace import get_tracer_provider

        tracer_provider = get_tracer_provider()
        metric_provider = get_meter_provider()
        if dify_config.DEBUG:
            logger.info("Initializing OpenTelemetry for Celery worker")
        CeleryInstrumentor(tracer_provider=tracer_provider, meter_provider=metric_provider).instrument()
9
dify/api/extensions/ext_proxy_fix.py
Normal file
@@ -0,0 +1,9 @@
from configs import dify_config
from dify_app import DifyApp


def init_app(app: DifyApp):
    if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED:
        from werkzeug.middleware.proxy_fix import ProxyFix

        app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1)  # type: ignore[method-assign]
268
dify/api/extensions/ext_redis.py
Normal file
@@ -0,0 +1,268 @@
import functools
import logging
import ssl
from collections.abc import Callable
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Union

import redis
from redis import RedisError
from redis.cache import CacheConfig
from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection
from redis.sentinel import Sentinel

from configs import dify_config
from dify_app import DifyApp

if TYPE_CHECKING:
    from redis.lock import Lock

logger = logging.getLogger(__name__)


class RedisClientWrapper:
    """
    A wrapper class for the Redis client that addresses the issue where the global
    `redis_client` variable cannot be updated when a new Redis instance is returned
    by Sentinel.

    This class allows for deferred initialization of the Redis client, enabling the
    client to be re-initialized with a new instance when necessary. This is particularly
    useful in scenarios where the Redis instance may change dynamically, such as during
    a failover in a Sentinel-managed Redis setup.

    Attributes:
        _client: The actual Redis client instance. It remains None until
            initialized with the `initialize` method.

    Methods:
        initialize(client): Initializes the Redis client if it hasn't been initialized already.
        __getattr__(item): Delegates attribute access to the Redis client, raising an error
            if the client is not initialized.
    """

    _client: Union[redis.Redis, RedisCluster, None]

    def __init__(self) -> None:
        self._client = None

    def initialize(self, client: Union[redis.Redis, RedisCluster]) -> None:
        if self._client is None:
            self._client = client

    if TYPE_CHECKING:
        # Type hints for IDE support and static analysis
        # These are not executed at runtime but provide type information
        def get(self, name: str | bytes) -> Any: ...

        def set(
            self,
            name: str | bytes,
            value: Any,
            ex: int | None = None,
            px: int | None = None,
            nx: bool = False,
            xx: bool = False,
            keepttl: bool = False,
            get: bool = False,
            exat: int | None = None,
            pxat: int | None = None,
        ) -> Any: ...

        def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: ...
        def setnx(self, name: str | bytes, value: Any) -> Any: ...
        def delete(self, *names: str | bytes) -> Any: ...
        def incr(self, name: str | bytes, amount: int = 1) -> Any: ...
        def expire(
            self,
            name: str | bytes,
            time: int | timedelta,
            nx: bool = False,
            xx: bool = False,
            gt: bool = False,
            lt: bool = False,
        ) -> Any: ...
        def lock(
            self,
            name: str,
            timeout: float | None = None,
            sleep: float = 0.1,
            blocking: bool = True,
            blocking_timeout: float | None = None,
            thread_local: bool = True,
        ) -> Lock: ...
        def zadd(
            self,
            name: str | bytes,
            mapping: dict[str | bytes | int | float, float | int | str | bytes],
            nx: bool = False,
            xx: bool = False,
            ch: bool = False,
            incr: bool = False,
            gt: bool = False,
            lt: bool = False,
        ) -> Any: ...
        def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ...
        def zcard(self, name: str | bytes) -> Any: ...
        def getdel(self, name: str | bytes) -> Any: ...

    def __getattr__(self, item: str) -> Any:
        if self._client is None:
            raise RuntimeError("Redis client is not initialized. Call init_app first.")
        return getattr(self._client, item)


redis_client: RedisClientWrapper = RedisClientWrapper()


def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
    """Get SSL configuration for Redis connection."""
    if not dify_config.REDIS_USE_SSL:
        return Connection, {}

    cert_reqs_map = {
        "CERT_NONE": ssl.CERT_NONE,
        "CERT_OPTIONAL": ssl.CERT_OPTIONAL,
        "CERT_REQUIRED": ssl.CERT_REQUIRED,
    }
    ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE)

    ssl_kwargs = {
        "ssl_cert_reqs": ssl_cert_reqs,
        "ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS,
        "ssl_certfile": dify_config.REDIS_SSL_CERTFILE,
        "ssl_keyfile": dify_config.REDIS_SSL_KEYFILE,
    }

    return SSLConnection, ssl_kwargs


def _get_cache_configuration() -> CacheConfig | None:
    """Get client-side cache configuration if enabled."""
    if not dify_config.REDIS_ENABLE_CLIENT_SIDE_CACHE:
        return None

    resp_protocol = dify_config.REDIS_SERIALIZATION_PROTOCOL
    if resp_protocol < 3:
        raise ValueError("Client side cache is only supported in RESP3")

    return CacheConfig()


def _get_base_redis_params() -> dict[str, Any]:
    """Get base Redis connection parameters."""
    return {
        "username": dify_config.REDIS_USERNAME,
        "password": dify_config.REDIS_PASSWORD or None,
        "db": dify_config.REDIS_DB,
        "encoding": "utf-8",
        "encoding_errors": "strict",
        "decode_responses": False,
        "protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
        "cache_config": _get_cache_configuration(),
    }


def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
    """Create Redis client using Sentinel configuration."""
    if not dify_config.REDIS_SENTINELS:
        raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True")

    if not dify_config.REDIS_SENTINEL_SERVICE_NAME:
        raise ValueError("REDIS_SENTINEL_SERVICE_NAME must be set when REDIS_USE_SENTINEL is True")

    sentinel_hosts = [(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")]

    sentinel = Sentinel(
        sentinel_hosts,
        sentinel_kwargs={
            "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT,
            "username": dify_config.REDIS_SENTINEL_USERNAME,
            "password": dify_config.REDIS_SENTINEL_PASSWORD,
        },
    )

    master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
    return master


def _create_cluster_client() -> Union[redis.Redis, RedisCluster]:
    """Create Redis cluster client."""
    if not dify_config.REDIS_CLUSTERS:
        raise ValueError("REDIS_CLUSTERS must be set when REDIS_USE_CLUSTERS is True")

    nodes = [
        ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1]))
        for node in dify_config.REDIS_CLUSTERS.split(",")
    ]

    cluster: RedisCluster = RedisCluster(
        startup_nodes=nodes,
        password=dify_config.REDIS_CLUSTERS_PASSWORD,
        protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL,
        cache_config=_get_cache_configuration(),
    )
    return cluster


def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
    """Create standalone Redis client."""
    connection_class, ssl_kwargs = _get_ssl_configuration()

    redis_params.update(
        {
            "host": dify_config.REDIS_HOST,
            "port": dify_config.REDIS_PORT,
            "connection_class": connection_class,
        }
    )

    if ssl_kwargs:
        redis_params.update(ssl_kwargs)

    pool = redis.ConnectionPool(**redis_params)
    client: redis.Redis = redis.Redis(connection_pool=pool)
    return client


def init_app(app: DifyApp):
    """Initialize Redis client and attach it to the app."""
    global redis_client

    # Determine Redis mode and create appropriate client
    if dify_config.REDIS_USE_SENTINEL:
        redis_params = _get_base_redis_params()
        client = _create_sentinel_client(redis_params)
    elif dify_config.REDIS_USE_CLUSTERS:
        client = _create_cluster_client()
    else:
        redis_params = _get_base_redis_params()
        client = _create_standalone_client(redis_params)

    # Initialize the wrapper and attach to app
    redis_client.initialize(client)
    app.extensions["redis"] = redis_client


def redis_fallback(default_return: Any | None = None):
    """
    Decorator to handle Redis operation exceptions and return a default value when Redis is unavailable.

    Args:
        default_return: The value to return when a Redis operation fails. Defaults to None.
    """

    def decorator(func: Callable):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except RedisError as e:
                func_name = getattr(func, "__name__", "Unknown")
                logger.warning("Redis operation failed in %s: %s", func_name, str(e), exc_info=True)
                return default_return

        return wrapper

    return decorator
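A usage sketch for the redis_fallback decorator defined above: if Redis is unreachable, the wrapped call logs a warning and returns the default value instead of raising.

    from extensions.ext_redis import redis_client, redis_fallback

    @redis_fallback(default_return=None)
    def get_cached_value(key: str):
        # Any RedisError raised here is caught by the decorator
        return redis_client.get(key)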
73
dify/api/extensions/ext_request_logging.py
Normal file
@@ -0,0 +1,73 @@
import json
import logging

import flask
import werkzeug.http
from flask import Flask
from flask.signals import request_finished, request_started

from configs import dify_config

logger = logging.getLogger(__name__)


def _is_content_type_json(content_type: str) -> bool:
    if not content_type:
        return False
    content_type_no_option, _ = werkzeug.http.parse_options_header(content_type)
    return content_type_no_option.lower() == "application/json"


def _log_request_started(_sender, **_extra):
    """Log the start of a request."""
    if not logger.isEnabledFor(logging.DEBUG):
        return

    request = flask.request
    if not (_is_content_type_json(request.content_type) and request.data):
        logger.debug("Received Request %s -> %s", request.method, request.path)
        return
    try:
        json_data = json.loads(request.data)
    except (TypeError, ValueError):
        logger.exception("Failed to parse JSON request")
        return
    formatted_json = json.dumps(json_data, ensure_ascii=False, indent=2)
    logger.debug(
        "Received Request %s -> %s, Request Body:\n%s",
        request.method,
        request.path,
        formatted_json,
    )


def _log_request_finished(_sender, response, **_extra):
    """Log the end of a request."""
    if not logger.isEnabledFor(logging.DEBUG) or response is None:
        return

    if not _is_content_type_json(response.content_type):
        logger.debug("Response %s %s", response.status, response.content_type)
        return

    response_data = response.get_data(as_text=True)
    try:
        json_data = json.loads(response_data)
    except (TypeError, ValueError):
        logger.exception("Failed to parse JSON response")
        return
    formatted_json = json.dumps(json_data, ensure_ascii=False, indent=2)
    logger.debug(
        "Response %s %s, Response Body:\n%s",
        response.status,
        response.content_type,
        formatted_json,
    )


def init_app(app: Flask):
    """Initialize the request logging extension."""
    if not dify_config.ENABLE_REQUEST_LOGGING:
        return
    request_started.connect(_log_request_started, app)
    request_finished.connect(_log_request_finished, app)
38
dify/api/extensions/ext_sentry.py
Normal file
@@ -0,0 +1,38 @@
from configs import dify_config
from dify_app import DifyApp


def init_app(app: DifyApp):
    if dify_config.SENTRY_DSN:
        import sentry_sdk
        from langfuse import parse_error
        from sentry_sdk.integrations.celery import CeleryIntegration
        from sentry_sdk.integrations.flask import FlaskIntegration
        from werkzeug.exceptions import HTTPException

        from core.model_runtime.errors.invoke import InvokeRateLimitError

        def before_send(event, hint):
            if "exc_info" in hint:
                _, exc_value, _ = hint["exc_info"]
                if parse_error.defaultErrorResponse in str(exc_value):
                    return None

            return event

        sentry_sdk.init(
            dsn=dify_config.SENTRY_DSN,
            integrations=[FlaskIntegration(), CeleryIntegration()],
            ignore_errors=[
                HTTPException,
                ValueError,
                FileNotFoundError,
                InvokeRateLimitError,
                parse_error.defaultErrorResponse,
            ],
            traces_sample_rate=dify_config.SENTRY_TRACES_SAMPLE_RATE,
            profiles_sample_rate=dify_config.SENTRY_PROFILES_SAMPLE_RATE,
            environment=dify_config.DEPLOY_ENV,
            release=f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
            before_send=before_send,
        )
6
dify/api/extensions/ext_set_secretkey.py
Normal file
@@ -0,0 +1,6 @@
from configs import dify_config
from dify_app import DifyApp


def init_app(app: DifyApp):
    app.secret_key = dify_config.SECRET_KEY
126
dify/api/extensions/ext_storage.py
Normal file
@@ -0,0 +1,126 @@
import logging
from collections.abc import Callable, Generator
from typing import Literal, Union, overload

from flask import Flask

from configs import dify_config
from dify_app import DifyApp
from extensions.storage.base_storage import BaseStorage
from extensions.storage.storage_type import StorageType

logger = logging.getLogger(__name__)


class Storage:
    def init_app(self, app: Flask):
        storage_factory = self.get_storage_factory(dify_config.STORAGE_TYPE)
        with app.app_context():
            self.storage_runner = storage_factory()

    @staticmethod
    def get_storage_factory(storage_type: str) -> Callable[[], BaseStorage]:
        match storage_type:
            case StorageType.S3:
                from extensions.storage.aws_s3_storage import AwsS3Storage

                return AwsS3Storage
            case StorageType.OPENDAL:
                from extensions.storage.opendal_storage import OpenDALStorage

                return lambda: OpenDALStorage(dify_config.OPENDAL_SCHEME)
            case StorageType.LOCAL:
                from extensions.storage.opendal_storage import OpenDALStorage

                return lambda: OpenDALStorage(scheme="fs", root=dify_config.STORAGE_LOCAL_PATH)
            case StorageType.AZURE_BLOB:
                from extensions.storage.azure_blob_storage import AzureBlobStorage

                return AzureBlobStorage
            case StorageType.ALIYUN_OSS:
                from extensions.storage.aliyun_oss_storage import AliyunOssStorage

                return AliyunOssStorage
            case StorageType.GOOGLE_STORAGE:
                from extensions.storage.google_cloud_storage import GoogleCloudStorage

                return GoogleCloudStorage
            case StorageType.TENCENT_COS:
                from extensions.storage.tencent_cos_storage import TencentCosStorage

                return TencentCosStorage
            case StorageType.OCI_STORAGE:
                from extensions.storage.oracle_oci_storage import OracleOCIStorage

                return OracleOCIStorage
            case StorageType.HUAWEI_OBS:
                from extensions.storage.huawei_obs_storage import HuaweiObsStorage

                return HuaweiObsStorage
            case StorageType.BAIDU_OBS:
                from extensions.storage.baidu_obs_storage import BaiduObsStorage

                return BaiduObsStorage
            case StorageType.VOLCENGINE_TOS:
                from extensions.storage.volcengine_tos_storage import VolcengineTosStorage

                return VolcengineTosStorage
            case StorageType.SUPABASE:
                from extensions.storage.supabase_storage import SupabaseStorage

                return SupabaseStorage
            case StorageType.CLICKZETTA_VOLUME:
                from extensions.storage.clickzetta_volume.clickzetta_volume_storage import (
                    ClickZettaVolumeConfig,
                    ClickZettaVolumeStorage,
                )

                def create_clickzetta_volume_storage():
                    # ClickZettaVolumeConfig will automatically read from environment variables
                    # and fallback to CLICKZETTA_* config if CLICKZETTA_VOLUME_* is not set
                    volume_config = ClickZettaVolumeConfig()
                    return ClickZettaVolumeStorage(volume_config)

                return create_clickzetta_volume_storage
            case _:
                raise ValueError(f"unsupported storage type {storage_type}")

    def save(self, filename: str, data: bytes):
        self.storage_runner.save(filename, data)

    @overload
    def load(self, filename: str, /, *, stream: Literal[False] = False) -> bytes: ...

    @overload
    def load(self, filename: str, /, *, stream: Literal[True]) -> Generator: ...

    def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]:
        if stream:
            return self.load_stream(filename)
        else:
            return self.load_once(filename)

    def load_once(self, filename: str) -> bytes:
        return self.storage_runner.load_once(filename)

    def load_stream(self, filename: str) -> Generator:
        return self.storage_runner.load_stream(filename)

    def download(self, filename, target_filepath):
        self.storage_runner.download(filename, target_filepath)

    def exists(self, filename):
        return self.storage_runner.exists(filename)

    def delete(self, filename: str):
        return self.storage_runner.delete(filename)

    def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
|
||||
return self.storage_runner.scan(path, files=files, directories=directories)
|
||||
|
||||
|
||||
storage = Storage()
|
||||
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
storage.init_app(app)
|
||||
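A hypothetical use of the module-level storage facade once init_app() has run inside a configured Dify application. The key name and payload are illustrative; the calls mirror the overloads defined above, where load() returns bytes by default and a generator when stream=True:

from extensions.ext_storage import storage

storage.save("upload_files/demo/hello.txt", b"hello world")

if storage.exists("upload_files/demo/hello.txt"):
    blob = storage.load("upload_files/demo/hello.txt")                # bytes
    for chunk in storage.load("upload_files/demo/hello.txt", stream=True):
        pass                                                          # streamed chunks
    storage.delete("upload_files/demo/hello.txt")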
11
dify/api/extensions/ext_timezone.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
from dify_app import DifyApp
|
||||
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
os.environ["TZ"] = "UTC"
|
||||
# Windows does not support tzset
|
||||
if hasattr(time, "tzset"):
|
||||
time.tzset()
|
||||
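A quick demonstration of what the TZ/tzset pair does to the process clock. On a POSIX system the snippet below reports UTC regardless of the host timezone; on Windows the tzset call is skipped, exactly as guarded above:

import os
import time

os.environ["TZ"] = "UTC"
if hasattr(time, "tzset"):          # same guard as the extension; no-op on Windows
    time.tzset()

print(time.strftime("%Z"))          # -> UTC
print(time.localtime().tm_gmtoff)   # -> 0 once local time is pinned to UTC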
7
dify/api/extensions/ext_warnings.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from dify_app import DifyApp
|
||||
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
import warnings
|
||||
|
||||
warnings.simplefilter("ignore", ResourceWarning)
|
||||
56
dify/api/extensions/storage/aliyun_oss_storage.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import posixpath
|
||||
from collections.abc import Generator
|
||||
|
||||
import oss2 as aliyun_s3
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
|
||||
|
||||
class AliyunOssStorage(BaseStorage):
|
||||
"""Implementation for Aliyun OSS storage."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bucket_name = dify_config.ALIYUN_OSS_BUCKET_NAME
|
||||
self.folder = dify_config.ALIYUN_OSS_PATH
|
||||
oss_auth_method = aliyun_s3.Auth
|
||||
region = None
|
||||
if dify_config.ALIYUN_OSS_AUTH_VERSION == "v4":
|
||||
oss_auth_method = aliyun_s3.AuthV4
|
||||
region = dify_config.ALIYUN_OSS_REGION
|
||||
oss_auth = oss_auth_method(dify_config.ALIYUN_OSS_ACCESS_KEY, dify_config.ALIYUN_OSS_SECRET_KEY)
|
||||
self.client = aliyun_s3.Bucket(
|
||||
oss_auth,
|
||||
dify_config.ALIYUN_OSS_ENDPOINT,
|
||||
self.bucket_name,
|
||||
connect_timeout=30,
|
||||
region=region,
|
||||
)
|
||||
|
||||
def save(self, filename, data):
|
||||
self.client.put_object(self.__wrapper_folder_filename(filename), data)
|
||||
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
obj = self.client.get_object(self.__wrapper_folder_filename(filename))
|
||||
data = obj.read()
|
||||
if not isinstance(data, bytes):
|
||||
return b""
|
||||
return data
|
||||
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
obj = self.client.get_object(self.__wrapper_folder_filename(filename))
|
||||
while chunk := obj.read(4096):
|
||||
yield chunk
|
||||
|
||||
def download(self, filename: str, target_filepath):
|
||||
self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath)
|
||||
|
||||
def exists(self, filename: str):
|
||||
return self.client.object_exists(self.__wrapper_folder_filename(filename))
|
||||
|
||||
def delete(self, filename: str):
|
||||
self.client.delete_object(self.__wrapper_folder_filename(filename))
|
||||
|
||||
def __wrapper_folder_filename(self, filename: str) -> str:
|
||||
return posixpath.join(self.folder, filename) if self.folder else filename
|
||||
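The __wrapper_folder_filename helper only prefixes the configured folder using POSIX separators. A standalone illustration; "dify-assets" stands in for ALIYUN_OSS_PATH:

import posixpath

filename = "upload_files/a/b.txt"

print(posixpath.join("dify-assets", filename))  # -> dify-assets/upload_files/a/b.txt
print(posixpath.join("", filename))             # -> upload_files/a/b.txt (empty folder adds nothing)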
87
dify/api/extensions/storage/aws_s3_storage.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
|
||||
import boto3
|
||||
from botocore.client import Config
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AwsS3Storage(BaseStorage):
|
||||
"""Implementation for Amazon Web Services S3 storage."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bucket_name = dify_config.S3_BUCKET_NAME
|
||||
if dify_config.S3_USE_AWS_MANAGED_IAM:
|
||||
logger.info("Using AWS managed IAM role for S3")
|
||||
|
||||
session = boto3.Session()
|
||||
region_name = dify_config.S3_REGION
|
||||
self.client = session.client(service_name="s3", region_name=region_name)
|
||||
else:
|
||||
logger.info("Using ak and sk for S3")
|
||||
|
||||
self.client = boto3.client(
|
||||
"s3",
|
||||
aws_secret_access_key=dify_config.S3_SECRET_KEY,
|
||||
aws_access_key_id=dify_config.S3_ACCESS_KEY,
|
||||
endpoint_url=dify_config.S3_ENDPOINT,
|
||||
region_name=dify_config.S3_REGION,
|
||||
config=Config(s3={"addressing_style": dify_config.S3_ADDRESS_STYLE}),
|
||||
)
|
||||
# create bucket
|
||||
try:
|
||||
self.client.head_bucket(Bucket=self.bucket_name)
|
||||
except ClientError as e:
|
||||
# if the bucket does not exist, create it
|
||||
if e.response.get("Error", {}).get("Code") == "404":
|
||||
self.client.create_bucket(Bucket=self.bucket_name)
|
||||
# if the bucket is not accessible, skip creation; the bucket may exist but not be accessible with these credentials
|
||||
elif e.response.get("Error", {}).get("Code") == "403":
|
||||
pass
|
||||
else:
|
||||
# other error, raise exception
|
||||
raise
|
||||
|
||||
def save(self, filename, data):
|
||||
self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
|
||||
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
try:
|
||||
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
|
||||
except ClientError as ex:
|
||||
if ex.response.get("Error", {}).get("Code") == "NoSuchKey":
|
||||
raise FileNotFoundError("File not found")
|
||||
else:
|
||||
raise
|
||||
return data
|
||||
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
try:
|
||||
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
|
||||
yield from response["Body"].iter_chunks()
|
||||
except ClientError as ex:
|
||||
if ex.response.get("Error", {}).get("Code") == "NoSuchKey":
|
||||
raise FileNotFoundError("file not found")
|
||||
elif "reached max retries" in str(ex):
|
||||
raise ValueError("please do not request the same file too frequently")
|
||||
else:
|
||||
raise
|
||||
|
||||
def download(self, filename, target_filepath):
|
||||
self.client.download_file(self.bucket_name, filename, target_filepath)
|
||||
|
||||
def exists(self, filename):
|
||||
try:
|
||||
self.client.head_object(Bucket=self.bucket_name, Key=filename)
|
||||
return True
|
||||
except ClientError:
|
||||
return False
|
||||
|
||||
def delete(self, filename):
|
||||
self.client.delete_object(Bucket=self.bucket_name, Key=filename)
|
||||
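The bucket bootstrap and the load paths above branch on the error code carried by a botocore ClientError. That branching can be exercised without AWS by constructing ClientError instances by hand; the codes and operation name below are synthetic:

from botocore.exceptions import ClientError


def classify(error: ClientError) -> str:
    # Same lookup the constructor uses: error.response["Error"]["Code"].
    code = error.response.get("Error", {}).get("Code")
    if code == "404":
        return "create bucket"
    if code == "403":
        return "skip creation (bucket exists but is not accessible)"
    return "re-raise"


missing = ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadBucket")
forbidden = ClientError({"Error": {"Code": "403", "Message": "Forbidden"}}, "HeadBucket")

print(classify(missing))    # -> create bucket
print(classify(forbidden))  # -> skip creation (bucket exists but is not accessible)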
104
dify/api/extensions/storage/azure_blob_storage.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import timedelta
|
||||
|
||||
from azure.identity import ChainedTokenCredential, DefaultAzureCredential
|
||||
from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
class AzureBlobStorage(BaseStorage):
|
||||
"""Implementation for Azure Blob storage."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bucket_name = dify_config.AZURE_BLOB_CONTAINER_NAME
|
||||
self.account_url = dify_config.AZURE_BLOB_ACCOUNT_URL
|
||||
self.account_name = dify_config.AZURE_BLOB_ACCOUNT_NAME
|
||||
self.account_key = dify_config.AZURE_BLOB_ACCOUNT_KEY
|
||||
|
||||
self.credential: ChainedTokenCredential | None = None
|
||||
if self.account_key == "managedidentity":
|
||||
self.credential = DefaultAzureCredential()
|
||||
else:
|
||||
self.credential = None
|
||||
|
||||
def save(self, filename, data):
|
||||
if not self.bucket_name:
|
||||
return
|
||||
|
||||
client = self._sync_client()
|
||||
blob_container = client.get_container_client(container=self.bucket_name)
|
||||
blob_container.upload_blob(filename, data)
|
||||
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
if not self.bucket_name:
|
||||
raise FileNotFoundError("Azure bucket name is not configured.")
|
||||
|
||||
client = self._sync_client()
|
||||
blob = client.get_container_client(container=self.bucket_name)
|
||||
blob = blob.get_blob_client(blob=filename)
|
||||
data = blob.download_blob().readall()
|
||||
if not isinstance(data, bytes):
|
||||
raise TypeError(f"Expected bytes from blob.readall(), got {type(data).__name__}")
|
||||
return data
|
||||
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
if not self.bucket_name:
|
||||
raise FileNotFoundError("Azure bucket name is not configured.")
|
||||
|
||||
client = self._sync_client()
|
||||
blob = client.get_blob_client(container=self.bucket_name, blob=filename)
|
||||
blob_data = blob.download_blob()
|
||||
yield from blob_data.chunks()
|
||||
|
||||
def download(self, filename, target_filepath):
|
||||
if not self.bucket_name:
|
||||
return
|
||||
|
||||
client = self._sync_client()
|
||||
|
||||
blob = client.get_blob_client(container=self.bucket_name, blob=filename)
|
||||
with open(target_filepath, "wb") as my_blob:
|
||||
blob_data = blob.download_blob()
|
||||
blob_data.readinto(my_blob)
|
||||
|
||||
def exists(self, filename):
|
||||
if not self.bucket_name:
|
||||
return False
|
||||
|
||||
client = self._sync_client()
|
||||
|
||||
blob = client.get_blob_client(container=self.bucket_name, blob=filename)
|
||||
return blob.exists()
|
||||
|
||||
def delete(self, filename):
|
||||
if not self.bucket_name:
|
||||
return
|
||||
|
||||
client = self._sync_client()
|
||||
|
||||
blob_container = client.get_container_client(container=self.bucket_name)
|
||||
blob_container.delete_blob(filename)
|
||||
|
||||
def _sync_client(self):
|
||||
if self.account_key == "managedidentity":
|
||||
return BlobServiceClient(account_url=self.account_url, credential=self.credential) # type: ignore
|
||||
|
||||
cache_key = f"azure_blob_sas_token_{self.account_name}_{self.account_key}"
|
||||
cache_result = redis_client.get(cache_key)
|
||||
if cache_result is not None:
|
||||
sas_token = cache_result.decode("utf-8")
|
||||
else:
|
||||
sas_token = generate_account_sas(
|
||||
account_name=self.account_name or "",
|
||||
account_key=self.account_key or "",
|
||||
resource_types=ResourceTypes(service=True, container=True, object=True),
|
||||
permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
|
||||
expiry=naive_utc_now() + timedelta(hours=1),
|
||||
)
|
||||
redis_client.set(cache_key, sas_token, ex=3000)
|
||||
return BlobServiceClient(account_url=self.account_url or "", credential=sas_token)
|
||||
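One detail worth noting in _sync_client: the SAS token is cached in Redis for 3000 seconds while the token itself is valid for one hour, so a cached token always has roughly ten minutes of validity left when it is reused. A tiny sketch of that margin:

from datetime import timedelta

sas_lifetime = timedelta(hours=1)     # expiry passed to generate_account_sas
cache_ttl = timedelta(seconds=3000)   # ex=3000 used for the Redis key

print(sas_lifetime - cache_ttl)       # -> 0:10:00 of slack before the token expires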
57
dify/api/extensions/storage/baidu_obs_storage.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import base64
|
||||
import hashlib
|
||||
from collections.abc import Generator
|
||||
|
||||
from baidubce.auth.bce_credentials import BceCredentials
|
||||
from baidubce.bce_client_configuration import BceClientConfiguration
|
||||
from baidubce.services.bos.bos_client import BosClient
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
|
||||
|
||||
class BaiduObsStorage(BaseStorage):
|
||||
"""Implementation for Baidu OBS storage."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bucket_name = dify_config.BAIDU_OBS_BUCKET_NAME
|
||||
client_config = BceClientConfiguration(
|
||||
credentials=BceCredentials(
|
||||
access_key_id=dify_config.BAIDU_OBS_ACCESS_KEY,
|
||||
secret_access_key=dify_config.BAIDU_OBS_SECRET_KEY,
|
||||
),
|
||||
endpoint=dify_config.BAIDU_OBS_ENDPOINT,
|
||||
)
|
||||
|
||||
self.client = BosClient(config=client_config)
|
||||
|
||||
def save(self, filename, data):
|
||||
md5 = hashlib.md5()
|
||||
md5.update(data)
|
||||
content_md5 = base64.standard_b64encode(md5.digest())
|
||||
self.client.put_object(
|
||||
bucket_name=self.bucket_name, key=filename, data=data, content_length=len(data), content_md5=content_md5
|
||||
)
|
||||
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
response = self.client.get_object(bucket_name=self.bucket_name, key=filename)
|
||||
data: bytes = response.data.read()
|
||||
return data
|
||||
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data
|
||||
while chunk := response.read(4096):
|
||||
yield chunk
|
||||
|
||||
def download(self, filename, target_filepath):
|
||||
self.client.get_object_to_file(bucket_name=self.bucket_name, key=filename, file_name=target_filepath)
|
||||
|
||||
def exists(self, filename):
|
||||
res = self.client.get_object_meta_data(bucket_name=self.bucket_name, key=filename)
|
||||
if res is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
def delete(self, filename):
|
||||
self.client.delete_object(bucket_name=self.bucket_name, key=filename)
|
||||
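save() sends a Content-MD5 value alongside the payload so the service can verify integrity; the value is the base64-encoded raw MD5 digest, which is easy to reproduce on its own:

import base64
import hashlib

data = b"hello baidu obs"
content_md5 = base64.standard_b64encode(hashlib.md5(data).digest())
print(content_md5.decode())  # same value the put_object call above would send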
40
dify/api/extensions/storage/base_storage.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Abstract interface for file storage implementations."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
|
||||
|
||||
class BaseStorage(ABC):
|
||||
"""Interface for file storage."""
|
||||
|
||||
@abstractmethod
|
||||
def save(self, filename: str, data: bytes):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def download(self, filename, target_filepath):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def exists(self, filename):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, filename):
|
||||
raise NotImplementedError
|
||||
|
||||
def scan(self, path, files=True, directories=False) -> list[str]:
|
||||
"""
|
||||
Scan files and directories in the given path.
|
||||
This method is implemented only in some storage backends.
|
||||
If a storage backend doesn't support scanning, it will raise NotImplementedError.
|
||||
"""
|
||||
raise NotImplementedError("This storage backend doesn't support scanning")
|
||||
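Because BaseStorage is a plain ABC, a toy backend is easy to sketch, which can be handy in unit tests. A minimal in-memory implementation; it is not part of Dify and only assumes the repo-local import path shown above:

from collections.abc import Generator

from extensions.storage.base_storage import BaseStorage


class InMemoryStorage(BaseStorage):
    """Toy backend for tests; keeps everything in a dict."""

    def __init__(self):
        self._files: dict[str, bytes] = {}

    def save(self, filename: str, data: bytes):
        self._files[filename] = data

    def load_once(self, filename: str) -> bytes:
        if filename not in self._files:
            raise FileNotFoundError(filename)
        return self._files[filename]

    def load_stream(self, filename: str) -> Generator:
        data = self.load_once(filename)
        for i in range(0, len(data), 4096):
            yield data[i : i + 4096]

    def download(self, filename, target_filepath):
        with open(target_filepath, "wb") as f:
            f.write(self.load_once(filename))

    def exists(self, filename):
        return filename in self._files

    def delete(self, filename):
        self._files.pop(filename, None)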
@@ -0,0 +1,5 @@
|
||||
"""ClickZetta Volume storage implementation."""
|
||||
|
||||
from .clickzetta_volume_storage import ClickZettaVolumeStorage
|
||||
|
||||
__all__ = ["ClickZettaVolumeStorage"]
|
||||
@@ -0,0 +1,528 @@
|
||||
"""ClickZetta Volume Storage Implementation
|
||||
|
||||
This module provides storage backend using ClickZetta Volume functionality.
|
||||
Supports Table Volume, User Volume, and External Volume types.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from collections.abc import Generator
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import clickzetta
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
|
||||
from .volume_permissions import VolumePermissionManager, check_volume_permission
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClickZettaVolumeConfig(BaseModel):
|
||||
"""Configuration for ClickZetta Volume storage."""
|
||||
|
||||
username: str = ""
|
||||
password: str = ""
|
||||
instance: str = ""
|
||||
service: str = "api.clickzetta.com"
|
||||
workspace: str = "quick_start"
|
||||
vcluster: str = "default_ap"
|
||||
schema_name: str = "dify"
|
||||
volume_type: str = "table" # table|user|external
|
||||
volume_name: str | None = None # For external volumes
|
||||
table_prefix: str = "dataset_" # Prefix for table volume names
|
||||
dify_prefix: str = "dify_km" # Directory prefix for User Volume
|
||||
permission_check: bool = True # Enable/disable permission checking
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict):
|
||||
"""Validate the configuration values.
|
||||
|
||||
This method will first try to use CLICKZETTA_VOLUME_* environment variables,
|
||||
then fall back to CLICKZETTA_* environment variables (for vector DB config).
|
||||
"""
|
||||
|
||||
# Helper function to get environment variable with fallback
|
||||
def get_env_with_fallback(volume_key: str, fallback_key: str, default: str | None = None) -> str:
|
||||
# First try CLICKZETTA_VOLUME_* specific config
|
||||
volume_value = values.get(volume_key.lower().replace("clickzetta_volume_", ""))
|
||||
if volume_value:
|
||||
return str(volume_value)
|
||||
|
||||
# Then try environment variables
|
||||
volume_env = os.getenv(volume_key)
|
||||
if volume_env:
|
||||
return volume_env
|
||||
|
||||
# Fall back to existing CLICKZETTA_* config
|
||||
fallback_env = os.getenv(fallback_key)
|
||||
if fallback_env:
|
||||
return fallback_env
|
||||
|
||||
return default or ""
|
||||
|
||||
# Apply environment variables with fallback to existing CLICKZETTA_* config
|
||||
values.setdefault("username", get_env_with_fallback("CLICKZETTA_VOLUME_USERNAME", "CLICKZETTA_USERNAME"))
|
||||
values.setdefault("password", get_env_with_fallback("CLICKZETTA_VOLUME_PASSWORD", "CLICKZETTA_PASSWORD"))
|
||||
values.setdefault("instance", get_env_with_fallback("CLICKZETTA_VOLUME_INSTANCE", "CLICKZETTA_INSTANCE"))
|
||||
values.setdefault(
|
||||
"service", get_env_with_fallback("CLICKZETTA_VOLUME_SERVICE", "CLICKZETTA_SERVICE", "api.clickzetta.com")
|
||||
)
|
||||
values.setdefault(
|
||||
"workspace", get_env_with_fallback("CLICKZETTA_VOLUME_WORKSPACE", "CLICKZETTA_WORKSPACE", "quick_start")
|
||||
)
|
||||
values.setdefault(
|
||||
"vcluster", get_env_with_fallback("CLICKZETTA_VOLUME_VCLUSTER", "CLICKZETTA_VCLUSTER", "default_ap")
|
||||
)
|
||||
values.setdefault("schema_name", get_env_with_fallback("CLICKZETTA_VOLUME_SCHEMA", "CLICKZETTA_SCHEMA", "dify"))
|
||||
|
||||
# Volume-specific configurations (no fallback to vector DB config)
|
||||
values.setdefault("volume_type", os.getenv("CLICKZETTA_VOLUME_TYPE", "table"))
|
||||
values.setdefault("volume_name", os.getenv("CLICKZETTA_VOLUME_NAME"))
|
||||
values.setdefault("table_prefix", os.getenv("CLICKZETTA_VOLUME_TABLE_PREFIX", "dataset_"))
|
||||
values.setdefault("dify_prefix", os.getenv("CLICKZETTA_VOLUME_DIFY_PREFIX", "dify_km"))
|
||||
# Temporarily disable the permission-check feature; default it directly to False
|
||||
values.setdefault("permission_check", False)
|
||||
|
||||
# Validate required fields
|
||||
if not values.get("username"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_USERNAME or CLICKZETTA_USERNAME is required")
|
||||
if not values.get("password"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_PASSWORD or CLICKZETTA_PASSWORD is required")
|
||||
if not values.get("instance"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_INSTANCE or CLICKZETTA_INSTANCE is required")
|
||||
|
||||
# Validate volume type
|
||||
volume_type = values["volume_type"]
|
||||
if volume_type not in ["table", "user", "external"]:
|
||||
raise ValueError("CLICKZETTA_VOLUME_TYPE must be one of: table, user, external")
|
||||
|
||||
if volume_type == "external" and not values.get("volume_name"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_NAME is required for external volume type")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class ClickZettaVolumeStorage(BaseStorage):
|
||||
"""ClickZetta Volume storage implementation."""
|
||||
|
||||
def __init__(self, config: ClickZettaVolumeConfig):
|
||||
"""Initialize ClickZetta Volume storage.
|
||||
|
||||
Args:
|
||||
config: ClickZetta Volume configuration
|
||||
"""
|
||||
self._config = config
|
||||
self._connection = None
|
||||
self._permission_manager: VolumePermissionManager | None = None
|
||||
self._init_connection()
|
||||
self._init_permission_manager()
|
||||
|
||||
logger.info("ClickZetta Volume storage initialized with type: %s", config.volume_type)
|
||||
|
||||
def _init_connection(self):
|
||||
"""Initialize ClickZetta connection."""
|
||||
try:
|
||||
self._connection = clickzetta.connect(
|
||||
username=self._config.username,
|
||||
password=self._config.password,
|
||||
instance=self._config.instance,
|
||||
service=self._config.service,
|
||||
workspace=self._config.workspace,
|
||||
vcluster=self._config.vcluster,
|
||||
schema=self._config.schema_name,
|
||||
)
|
||||
logger.debug("ClickZetta connection established")
|
||||
except Exception:
|
||||
logger.exception("Failed to connect to ClickZetta")
|
||||
raise
|
||||
|
||||
def _init_permission_manager(self):
|
||||
"""Initialize permission manager."""
|
||||
try:
|
||||
self._permission_manager = VolumePermissionManager(
|
||||
self._connection, self._config.volume_type, self._config.volume_name
|
||||
)
|
||||
logger.debug("Permission manager initialized")
|
||||
except Exception:
|
||||
logger.exception("Failed to initialize permission manager")
|
||||
raise
|
||||
|
||||
def _get_volume_path(self, filename: str, dataset_id: str | None = None) -> str:
|
||||
"""Get the appropriate volume path based on volume type."""
|
||||
if self._config.volume_type == "user":
|
||||
# Add dify prefix for User Volume to organize files
|
||||
return f"{self._config.dify_prefix}/{filename}"
|
||||
elif self._config.volume_type == "table":
|
||||
# Check if this should use User Volume (special directories)
|
||||
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
# Use User Volume with dify prefix for special directories
|
||||
return f"{self._config.dify_prefix}/{filename}"
|
||||
|
||||
if dataset_id:
|
||||
return f"{self._config.table_prefix}{dataset_id}/{filename}"
|
||||
else:
|
||||
# Extract dataset_id from filename if not provided
|
||||
# Format: dataset_id/filename
|
||||
if "/" in filename:
|
||||
return filename
|
||||
else:
|
||||
raise ValueError("dataset_id is required for table volume or filename must include dataset_id/")
|
||||
elif self._config.volume_type == "external":
|
||||
return filename
|
||||
else:
|
||||
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
|
||||
|
||||
def _get_volume_sql_prefix(self, dataset_id: str | None = None) -> str:
|
||||
"""Get SQL prefix for volume operations."""
|
||||
if self._config.volume_type == "user":
|
||||
return "USER VOLUME"
|
||||
elif self._config.volume_type == "table":
|
||||
# For Dify's current file storage pattern, most files are stored in
|
||||
# paths like "upload_files/tenant_id/uuid.ext", "tools/tenant_id/uuid.ext"
|
||||
# These should use USER VOLUME for better compatibility
|
||||
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
return "USER VOLUME"
|
||||
|
||||
# Only use TABLE VOLUME for actual dataset-specific paths
|
||||
# like "dataset_12345/file.pdf" or paths with dataset_ prefix
|
||||
if dataset_id:
|
||||
table_name = f"{self._config.table_prefix}{dataset_id}"
|
||||
else:
|
||||
# Default table name for generic operations
|
||||
table_name = "default_dataset"
|
||||
return f"TABLE VOLUME {table_name}"
|
||||
elif self._config.volume_type == "external":
|
||||
return f"VOLUME {self._config.volume_name}"
|
||||
else:
|
||||
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
|
||||
|
||||
def _execute_sql(self, sql: str, fetch: bool = False):
|
||||
"""Execute SQL command."""
|
||||
try:
|
||||
if self._connection is None:
|
||||
raise RuntimeError("Connection not initialized")
|
||||
with self._connection.cursor() as cursor:
|
||||
cursor.execute(sql)
|
||||
if fetch:
|
||||
return cursor.fetchall()
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("SQL execution failed: %s", sql)
|
||||
raise
|
||||
|
||||
def _ensure_table_volume_exists(self, dataset_id: str):
|
||||
"""Ensure table volume exists for the given dataset_id."""
|
||||
if self._config.volume_type != "table" or not dataset_id:
|
||||
return
|
||||
|
||||
# Skip for upload_files and other special directories that use USER VOLUME
|
||||
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
return
|
||||
|
||||
table_name = f"{self._config.table_prefix}{dataset_id}"
|
||||
|
||||
try:
|
||||
# Check if table exists
|
||||
check_sql = f"SHOW TABLES LIKE '{table_name}'"
|
||||
result = self._execute_sql(check_sql, fetch=True)
|
||||
|
||||
if not result:
|
||||
# Create table with volume
|
||||
create_sql = f"""
|
||||
CREATE TABLE {table_name} (
|
||||
id INT PRIMARY KEY AUTO_INCREMENT,
|
||||
filename VARCHAR(255) NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
INDEX idx_filename (filename)
|
||||
) WITH VOLUME
|
||||
"""
|
||||
self._execute_sql(create_sql)
|
||||
logger.info("Created table volume: %s", table_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to create table volume %s: %s", table_name, e)
|
||||
# Don't raise exception, let the operation continue
|
||||
# The table might exist but not be visible due to permissions
|
||||
|
||||
def save(self, filename: str, data: bytes):
|
||||
"""Save data to ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
data: File content as bytes
|
||||
"""
|
||||
# Extract dataset_id from filename if present
|
||||
dataset_id = None
|
||||
if "/" in filename and self._config.volume_type == "table":
|
||||
parts = filename.split("/", 1)
|
||||
if parts[0].startswith(self._config.table_prefix):
|
||||
dataset_id = parts[0][len(self._config.table_prefix) :]
|
||||
filename = parts[1]
|
||||
else:
|
||||
dataset_id = parts[0]
|
||||
filename = parts[1]
|
||||
|
||||
# Ensure table volume exists (for table volumes)
|
||||
if dataset_id:
|
||||
self._ensure_table_volume_exists(dataset_id)
|
||||
|
||||
# Check permissions (if enabled)
|
||||
if self._config.permission_check:
|
||||
# Skip permission check for special directories that use USER VOLUME
|
||||
if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
if self._permission_manager is not None:
|
||||
check_volume_permission(self._permission_manager, "save", dataset_id)
|
||||
|
||||
# Write data to temporary file
|
||||
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
||||
temp_file.write(data)
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
try:
|
||||
# Upload to volume
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# Get the actual volume path (may include dify_km prefix)
|
||||
volume_path = self._get_volume_path(filename, dataset_id)
|
||||
|
||||
# For User Volume, use the full path with dify_km prefix
|
||||
if volume_prefix == "USER VOLUME":
|
||||
sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{volume_path}'"
|
||||
else:
|
||||
sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{filename}'"
|
||||
|
||||
self._execute_sql(sql)
|
||||
logger.debug("File %s saved to ClickZetta Volume at path %s", filename, volume_path)
|
||||
finally:
|
||||
# Clean up temporary file
|
||||
Path(temp_file_path).unlink(missing_ok=True)
|
||||
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
"""Load file content from ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
|
||||
Returns:
|
||||
File content as bytes
|
||||
"""
|
||||
# Extract dataset_id from filename if present
|
||||
dataset_id = None
|
||||
if "/" in filename and self._config.volume_type == "table":
|
||||
parts = filename.split("/", 1)
|
||||
if parts[0].startswith(self._config.table_prefix):
|
||||
dataset_id = parts[0][len(self._config.table_prefix) :]
|
||||
filename = parts[1]
|
||||
else:
|
||||
dataset_id = parts[0]
|
||||
filename = parts[1]
|
||||
|
||||
# Check permissions (if enabled)
|
||||
if self._config.permission_check:
|
||||
# Skip permission check for special directories that use USER VOLUME
|
||||
if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
if self._permission_manager is not None:
|
||||
check_volume_permission(self._permission_manager, "load_once", dataset_id)
|
||||
|
||||
# Download to temporary directory
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# Get the actual volume path (may include dify_km prefix)
|
||||
volume_path = self._get_volume_path(filename, dataset_id)
|
||||
|
||||
# For User Volume, use the full path with dify_km prefix
|
||||
if volume_prefix == "USER VOLUME":
|
||||
sql = f"GET {volume_prefix} FILE '{volume_path}' TO '{temp_dir}'"
|
||||
else:
|
||||
sql = f"GET {volume_prefix} FILE '{filename}' TO '{temp_dir}'"
|
||||
|
||||
self._execute_sql(sql)
|
||||
|
||||
# Find the downloaded file (may be in subdirectories)
|
||||
downloaded_file = None
|
||||
for root, _, files in os.walk(temp_dir):
|
||||
for file in files:
|
||||
if file == filename or file == os.path.basename(filename):
|
||||
downloaded_file = Path(root) / file
|
||||
break
|
||||
if downloaded_file:
|
||||
break
|
||||
|
||||
if not downloaded_file or not downloaded_file.exists():
|
||||
raise FileNotFoundError(f"Downloaded file not found: {filename}")
|
||||
|
||||
content = downloaded_file.read_bytes()
|
||||
|
||||
logger.debug("File %s loaded from ClickZetta Volume", filename)
|
||||
return content
|
||||
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
"""Load file as stream from ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
|
||||
Yields:
|
||||
File content chunks
|
||||
"""
|
||||
content = self.load_once(filename)
|
||||
batch_size = 4096
|
||||
stream = BytesIO(content)
|
||||
|
||||
while chunk := stream.read(batch_size):
|
||||
yield chunk
|
||||
|
||||
logger.debug("File %s loaded as stream from ClickZetta Volume", filename)
|
||||
|
||||
def download(self, filename: str, target_filepath: str):
|
||||
"""Download file from ClickZetta Volume to local path.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
target_filepath: Local target file path
|
||||
"""
|
||||
content = self.load_once(filename)
|
||||
|
||||
with Path(target_filepath).open("wb") as f:
|
||||
f.write(content)
|
||||
|
||||
logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath)
|
||||
|
||||
def exists(self, filename: str) -> bool:
|
||||
"""Check if file exists in ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
|
||||
Returns:
|
||||
True if file exists, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Extract dataset_id from filename if present
|
||||
dataset_id = None
|
||||
if "/" in filename and self._config.volume_type == "table":
|
||||
parts = filename.split("/", 1)
|
||||
if parts[0].startswith(self._config.table_prefix):
|
||||
dataset_id = parts[0][len(self._config.table_prefix) :]
|
||||
filename = parts[1]
|
||||
else:
|
||||
dataset_id = parts[0]
|
||||
filename = parts[1]
|
||||
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# Get the actual volume path (may include dify_km prefix)
|
||||
volume_path = self._get_volume_path(filename, dataset_id)
|
||||
|
||||
# For User Volume, use the full path with dify_km prefix
|
||||
if volume_prefix == "USER VOLUME":
|
||||
sql = f"LIST {volume_prefix} REGEXP = '^{volume_path}$'"
|
||||
else:
|
||||
sql = f"LIST {volume_prefix} REGEXP = '^{filename}$'"
|
||||
|
||||
rows = self._execute_sql(sql, fetch=True)
|
||||
|
||||
exists = len(rows) > 0 if rows else False
|
||||
logger.debug("File %s exists check: %s", filename, exists)
|
||||
return exists
|
||||
except Exception as e:
|
||||
logger.warning("Error checking file existence for %s: %s", filename, e)
|
||||
return False
|
||||
|
||||
def delete(self, filename: str):
|
||||
"""Delete file from ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
"""
|
||||
if not self.exists(filename):
|
||||
logger.debug("File %s not found, skip delete", filename)
|
||||
return
|
||||
|
||||
# Extract dataset_id from filename if present
|
||||
dataset_id = None
|
||||
if "/" in filename and self._config.volume_type == "table":
|
||||
parts = filename.split("/", 1)
|
||||
if parts[0].startswith(self._config.table_prefix):
|
||||
dataset_id = parts[0][len(self._config.table_prefix) :]
|
||||
filename = parts[1]
|
||||
else:
|
||||
dataset_id = parts[0]
|
||||
filename = parts[1]
|
||||
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# Get the actual volume path (may include dify_km prefix)
|
||||
volume_path = self._get_volume_path(filename, dataset_id)
|
||||
|
||||
# For User Volume, use the full path with dify_km prefix
|
||||
if volume_prefix == "USER VOLUME":
|
||||
sql = f"REMOVE {volume_prefix} FILE '{volume_path}'"
|
||||
else:
|
||||
sql = f"REMOVE {volume_prefix} FILE '{filename}'"
|
||||
|
||||
self._execute_sql(sql)
|
||||
|
||||
logger.debug("File %s deleted from ClickZetta Volume", filename)
|
||||
|
||||
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
|
||||
"""Scan files and directories in ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
path: Path to scan (dataset_id for table volumes)
|
||||
files: Include files in results
|
||||
directories: Include directories in results
|
||||
|
||||
Returns:
|
||||
List of file/directory paths
|
||||
"""
|
||||
try:
|
||||
# For table volumes, path is treated as dataset_id
|
||||
dataset_id = None
|
||||
if self._config.volume_type == "table":
|
||||
dataset_id = path
|
||||
path = "" # Root of the table volume
|
||||
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# For User Volume, add dify prefix to path
|
||||
if volume_prefix == "USER VOLUME":
|
||||
if path:
|
||||
scan_path = f"{self._config.dify_prefix}/{path}"
|
||||
sql = f"LIST {volume_prefix} SUBDIRECTORY '{scan_path}'"
|
||||
else:
|
||||
sql = f"LIST {volume_prefix} SUBDIRECTORY '{self._config.dify_prefix}'"
|
||||
else:
|
||||
if path:
|
||||
sql = f"LIST {volume_prefix} SUBDIRECTORY '{path}'"
|
||||
else:
|
||||
sql = f"LIST {volume_prefix}"
|
||||
|
||||
rows = self._execute_sql(sql, fetch=True)
|
||||
|
||||
result = []
|
||||
if rows:
|
||||
for row in rows:
|
||||
file_path = row[0] # relative_path column
|
||||
|
||||
# For User Volume, remove dify prefix from results
|
||||
dify_prefix_with_slash = f"{self._config.dify_prefix}/"
|
||||
if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash):
|
||||
file_path = file_path[len(dify_prefix_with_slash) :] # Remove prefix
|
||||
|
||||
if (files and not file_path.endswith("/")) or (directories and file_path.endswith("/")):
|
||||
result.append(file_path)
|
||||
|
||||
logger.debug("Scanned %d items in path %s", len(result), path)
|
||||
return result
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error scanning path %s", path)
|
||||
return []
|
||||
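A hypothetical end-to-end use of the ClickZetta backend. It needs the clickzetta package and a reachable instance, so treat it as a sketch rather than something to paste and run; every credential below is a placeholder:

from extensions.storage.clickzetta_volume.clickzetta_volume_storage import (
    ClickZettaVolumeConfig,
    ClickZettaVolumeStorage,
)

config = ClickZettaVolumeConfig(
    username="demo_user",
    password="demo_pass",
    instance="demo_instance",
    volume_type="user",          # files land under the dify_km/ prefix in USER VOLUME
)
volume = ClickZettaVolumeStorage(config)

volume.save("upload_files/demo.txt", b"hello volume")
print(volume.exists("upload_files/demo.txt"))     # -> True
print(volume.load_once("upload_files/demo.txt"))  # -> b'hello volume'
volume.delete("upload_files/demo.txt")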
516
dify/api/extensions/storage/clickzetta_volume/file_lifecycle.py
Normal file
@@ -0,0 +1,516 @@
|
||||
"""ClickZetta Volume file lifecycle management
|
||||
|
||||
This module provides file lifecycle management features including version control,
|
||||
automatic cleanup, backup and restore.
|
||||
Supports complete lifecycle management for knowledge base files.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import operator
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileStatus(StrEnum):
|
||||
"""File status enumeration"""
|
||||
|
||||
ACTIVE = auto() # Active status
|
||||
ARCHIVED = auto() # Archived
|
||||
DELETED = auto() # Deleted (soft delete)
|
||||
BACKUP = auto() # Backup file
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileMetadata:
|
||||
"""File metadata"""
|
||||
|
||||
filename: str
|
||||
size: int | None
|
||||
created_at: datetime
|
||||
modified_at: datetime
|
||||
version: int | None
|
||||
status: FileStatus
|
||||
checksum: str | None = None
|
||||
tags: dict[str, str] | None = None
|
||||
parent_version: int | None = None
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert to dictionary format"""
|
||||
data = asdict(self)
|
||||
data["created_at"] = self.created_at.isoformat()
|
||||
data["modified_at"] = self.modified_at.isoformat()
|
||||
data["status"] = self.status.value
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "FileMetadata":
|
||||
"""Create instance from dictionary"""
|
||||
data = data.copy()
|
||||
data["created_at"] = datetime.fromisoformat(data["created_at"])
|
||||
data["modified_at"] = datetime.fromisoformat(data["modified_at"])
|
||||
data["status"] = FileStatus(data["status"])
|
||||
return cls(**data)
|
||||
|
||||
|
||||
class FileLifecycleManager:
|
||||
"""File lifecycle manager"""
|
||||
|
||||
def __init__(self, storage, dataset_id: str | None = None):
|
||||
"""Initialize lifecycle manager
|
||||
|
||||
Args:
|
||||
storage: ClickZetta Volume storage instance
|
||||
dataset_id: Dataset ID (for Table Volume)
|
||||
"""
|
||||
self._storage = storage
|
||||
self._dataset_id = dataset_id
|
||||
self._metadata_file = ".dify_file_metadata.json"
|
||||
self._version_prefix = ".versions/"
|
||||
self._backup_prefix = ".backups/"
|
||||
self._deleted_prefix = ".deleted/"
|
||||
|
||||
# Get permission manager (if exists)
|
||||
self._permission_manager: Any | None = getattr(storage, "_permission_manager", None)
|
||||
|
||||
def save_with_lifecycle(self, filename: str, data: bytes, tags: dict[str, str] | None = None) -> FileMetadata:
|
||||
"""Save file and manage lifecycle
|
||||
|
||||
Args:
|
||||
filename: File name
|
||||
data: File content
|
||||
tags: File tags
|
||||
|
||||
Returns:
|
||||
File metadata
|
||||
"""
|
||||
# Permission check
|
||||
if not self._check_permission(filename, "save"):
|
||||
from .volume_permissions import VolumePermissionError
|
||||
|
||||
raise VolumePermissionError(
|
||||
f"Permission denied for lifecycle save operation on file: {filename}",
|
||||
operation="save",
|
||||
volume_type=getattr(getattr(self._storage, "_config", None), "volume_type", "unknown"),
|
||||
dataset_id=self._dataset_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# 1. Check if old version exists
|
||||
metadata_dict = self._load_metadata()
|
||||
current_metadata = metadata_dict.get(filename)
|
||||
|
||||
# 2. If old version exists, create version backup
|
||||
if current_metadata:
|
||||
self._create_version_backup(filename, current_metadata)
|
||||
|
||||
# 3. Calculate file information
|
||||
now = datetime.now()
|
||||
checksum = self._calculate_checksum(data)
|
||||
new_version = (current_metadata["version"] + 1) if current_metadata else 1
|
||||
|
||||
# 4. Save new file
|
||||
self._storage.save(filename, data)
|
||||
|
||||
# 5. Create metadata
|
||||
created_at = now
|
||||
parent_version = None
|
||||
|
||||
if current_metadata:
|
||||
# If created_at is string, convert to datetime
|
||||
if isinstance(current_metadata["created_at"], str):
|
||||
created_at = datetime.fromisoformat(current_metadata["created_at"])
|
||||
else:
|
||||
created_at = current_metadata["created_at"]
|
||||
parent_version = current_metadata["version"]
|
||||
|
||||
file_metadata = FileMetadata(
|
||||
filename=filename,
|
||||
size=len(data),
|
||||
created_at=created_at,
|
||||
modified_at=now,
|
||||
version=new_version,
|
||||
status=FileStatus.ACTIVE,
|
||||
checksum=checksum,
|
||||
tags=tags or {},
|
||||
parent_version=parent_version,
|
||||
)
|
||||
|
||||
# 6. Update metadata
|
||||
metadata_dict[filename] = file_metadata.to_dict()
|
||||
self._save_metadata(metadata_dict)
|
||||
|
||||
logger.info("File %s saved with lifecycle management, version %s", filename, new_version)
|
||||
return file_metadata
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to save file with lifecycle")
|
||||
raise
|
||||
|
||||
def get_file_metadata(self, filename: str) -> FileMetadata | None:
|
||||
"""Get file metadata
|
||||
|
||||
Args:
|
||||
filename: File name
|
||||
|
||||
Returns:
|
||||
File metadata, returns None if not exists
|
||||
"""
|
||||
try:
|
||||
metadata_dict = self._load_metadata()
|
||||
if filename in metadata_dict:
|
||||
return FileMetadata.from_dict(metadata_dict[filename])
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("Failed to get file metadata for %s", filename)
|
||||
return None
|
||||
|
||||
def list_file_versions(self, filename: str) -> list[FileMetadata]:
|
||||
"""List all versions of a file
|
||||
|
||||
Args:
|
||||
filename: File name
|
||||
|
||||
Returns:
|
||||
File version list, sorted by version number
|
||||
"""
|
||||
try:
|
||||
versions = []
|
||||
|
||||
# Get current version
|
||||
current_metadata = self.get_file_metadata(filename)
|
||||
if current_metadata:
|
||||
versions.append(current_metadata)
|
||||
|
||||
# Get historical versions
|
||||
try:
|
||||
version_files = self._storage.scan(self._dataset_id or "", files=True)
|
||||
for file_path in version_files:
|
||||
if file_path.startswith(f"{self._version_prefix}{filename}.v"):
|
||||
# Parse version number
|
||||
version_str = file_path.split(".v")[-1].split(".")[0]
|
||||
try:
|
||||
_ = int(version_str)
|
||||
# Simplified handling here; the metadata should really be read from the version file
|
||||
# Temporarily create basic metadata information
|
||||
except ValueError:
|
||||
continue
|
||||
except Exception:
|
||||
# If the version files cannot be scanned, return only the current version
|
||||
pass
|
||||
|
||||
return sorted(versions, key=lambda x: x.version or 0, reverse=True)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to list file versions for %s", filename)
|
||||
return []
|
||||
|
||||
def restore_version(self, filename: str, version: int) -> bool:
|
||||
"""Restore file to specified version
|
||||
|
||||
Args:
|
||||
filename: File name
|
||||
version: Version number to restore
|
||||
|
||||
Returns:
|
||||
Whether restore succeeded
|
||||
"""
|
||||
try:
|
||||
version_filename = f"{self._version_prefix}{filename}.v{version}"
|
||||
|
||||
# Check if version file exists
|
||||
if not self._storage.exists(version_filename):
|
||||
logger.warning("Version %s of %s not found", version, filename)
|
||||
return False
|
||||
|
||||
# Read version file content
|
||||
version_data = self._storage.load_once(version_filename)
|
||||
|
||||
# Save current version as backup
|
||||
current_metadata = self.get_file_metadata(filename)
|
||||
if current_metadata:
|
||||
self._create_version_backup(filename, current_metadata.to_dict())
|
||||
|
||||
# Restore file
|
||||
self.save_with_lifecycle(filename, version_data, {"restored_from": str(version)})
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to restore %s to version %s", filename, version)
|
||||
return False
|
||||
|
||||
def archive_file(self, filename: str) -> bool:
|
||||
"""Archive file
|
||||
|
||||
Args:
|
||||
filename: File name
|
||||
|
||||
Returns:
|
||||
Whether archive succeeded
|
||||
"""
|
||||
# Permission check
|
||||
if not self._check_permission(filename, "archive"):
|
||||
logger.warning("Permission denied for archive operation on file: %s", filename)
|
||||
return False
|
||||
|
||||
try:
|
||||
# Update file status to archived
|
||||
metadata_dict = self._load_metadata()
|
||||
if filename not in metadata_dict:
|
||||
logger.warning("File %s not found in metadata", filename)
|
||||
return False
|
||||
|
||||
metadata_dict[filename]["status"] = FileStatus.ARCHIVED
|
||||
metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
|
||||
|
||||
self._save_metadata(metadata_dict)
|
||||
|
||||
logger.info("File %s archived successfully", filename)
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to archive file %s", filename)
|
||||
return False
|
||||
|
||||
def soft_delete_file(self, filename: str) -> bool:
|
||||
"""Soft delete file (move to deleted directory)
|
||||
|
||||
Args:
|
||||
filename: File name
|
||||
|
||||
Returns:
|
||||
Whether delete succeeded
|
||||
"""
|
||||
# Permission check
|
||||
if not self._check_permission(filename, "delete"):
|
||||
logger.warning("Permission denied for soft delete operation on file: %s", filename)
|
||||
return False
|
||||
|
||||
try:
|
||||
# Check if file exists
|
||||
if not self._storage.exists(filename):
|
||||
logger.warning("File %s not found", filename)
|
||||
return False
|
||||
|
||||
# Read file content
|
||||
file_data = self._storage.load_once(filename)
|
||||
|
||||
# Move to deleted directory
|
||||
deleted_filename = f"{self._deleted_prefix}{filename}.{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
self._storage.save(deleted_filename, file_data)
|
||||
|
||||
# Delete original file
|
||||
self._storage.delete(filename)
|
||||
|
||||
# Update metadata
|
||||
metadata_dict = self._load_metadata()
|
||||
if filename in metadata_dict:
|
||||
metadata_dict[filename]["status"] = FileStatus.DELETED
|
||||
metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
|
||||
self._save_metadata(metadata_dict)
|
||||
|
||||
logger.info("File %s soft deleted successfully", filename)
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to soft delete file %s", filename)
|
||||
return False
|
||||
|
||||
def cleanup_old_versions(self, max_versions: int = 5, max_age_days: int = 30) -> int:
|
||||
"""Cleanup old version files
|
||||
|
||||
Args:
|
||||
max_versions: Maximum number of versions to keep
|
||||
max_age_days: Maximum retention days for version files
|
||||
|
||||
Returns:
|
||||
Number of files cleaned
|
||||
"""
|
||||
try:
|
||||
cleaned_count = 0
|
||||
|
||||
# Get all version files
|
||||
try:
|
||||
all_files = self._storage.scan(self._dataset_id or "", files=True)
|
||||
version_files = [f for f in all_files if f.startswith(self._version_prefix)]
|
||||
|
||||
# Group by file
|
||||
file_versions: dict[str, list[tuple[int, str]]] = {}
|
||||
for version_file in version_files:
|
||||
# Parse filename and version
|
||||
parts = version_file[len(self._version_prefix) :].split(".v")
|
||||
if len(parts) >= 2:
|
||||
base_filename = parts[0]
|
||||
version_part = parts[1].split(".")[0]
|
||||
try:
|
||||
version_num = int(version_part)
|
||||
if base_filename not in file_versions:
|
||||
file_versions[base_filename] = []
|
||||
file_versions[base_filename].append((version_num, version_file))
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Cleanup old versions for each file
|
||||
for base_filename, versions in file_versions.items():
|
||||
# Sort by version number
|
||||
versions.sort(key=operator.itemgetter(0), reverse=True)
|
||||
|
||||
# Keep the newest max_versions versions, delete the rest
|
||||
if len(versions) > max_versions:
|
||||
to_delete = versions[max_versions:]
|
||||
for version_num, version_file in to_delete:
|
||||
self._storage.delete(version_file)
|
||||
cleaned_count += 1
|
||||
logger.debug("Cleaned old version: %s", version_file)
|
||||
|
||||
logger.info("Cleaned %d old version files", cleaned_count)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Could not scan for version files: %s", e)
|
||||
|
||||
return cleaned_count
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to cleanup old versions")
|
||||
return 0
|
||||
|
||||
def get_storage_statistics(self) -> dict[str, Any]:
|
||||
"""Get storage statistics
|
||||
|
||||
Returns:
|
||||
Storage statistics dictionary
|
||||
"""
|
||||
try:
|
||||
metadata_dict = self._load_metadata()
|
||||
|
||||
stats: dict[str, Any] = {
|
||||
"total_files": len(metadata_dict),
|
||||
"active_files": 0,
|
||||
"archived_files": 0,
|
||||
"deleted_files": 0,
|
||||
"total_size": 0,
|
||||
"versions_count": 0,
|
||||
"oldest_file": None,
|
||||
"newest_file": None,
|
||||
}
|
||||
|
||||
oldest_date = None
|
||||
newest_date = None
|
||||
|
||||
for filename, metadata in metadata_dict.items():
|
||||
file_meta = FileMetadata.from_dict(metadata)
|
||||
|
||||
# Count file status
|
||||
if file_meta.status == FileStatus.ACTIVE:
|
||||
stats["active_files"] = (stats["active_files"] or 0) + 1
|
||||
elif file_meta.status == FileStatus.ARCHIVED:
|
||||
stats["archived_files"] = (stats["archived_files"] or 0) + 1
|
||||
elif file_meta.status == FileStatus.DELETED:
|
||||
stats["deleted_files"] = (stats["deleted_files"] or 0) + 1
|
||||
|
||||
# Count size
|
||||
stats["total_size"] = (stats["total_size"] or 0) + (file_meta.size or 0)
|
||||
|
||||
# Count versions
|
||||
stats["versions_count"] = (stats["versions_count"] or 0) + (file_meta.version or 0)
|
||||
|
||||
# Find newest and oldest files
|
||||
if oldest_date is None or file_meta.created_at < oldest_date:
|
||||
oldest_date = file_meta.created_at
|
||||
stats["oldest_file"] = filename
|
||||
|
||||
if newest_date is None or file_meta.modified_at > newest_date:
|
||||
newest_date = file_meta.modified_at
|
||||
stats["newest_file"] = filename
|
||||
|
||||
return stats
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to get storage statistics")
|
||||
return {}
|
||||
|
||||
def _create_version_backup(self, filename: str, metadata: dict):
|
||||
"""Create version backup"""
|
||||
try:
|
||||
# Read current file content
|
||||
current_data = self._storage.load_once(filename)
|
||||
|
||||
# Save as version file
|
||||
version_filename = f"{self._version_prefix}{filename}.v{metadata['version']}"
|
||||
self._storage.save(version_filename, current_data)
|
||||
|
||||
logger.debug("Created version backup: %s", version_filename)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to create version backup for %s: %s", filename, e)
|
||||
|
||||
def _load_metadata(self) -> dict[str, Any]:
|
||||
"""Load metadata file"""
|
||||
try:
|
||||
if self._storage.exists(self._metadata_file):
|
||||
metadata_content = self._storage.load_once(self._metadata_file)
|
||||
result = json.loads(metadata_content.decode("utf-8"))
|
||||
return dict(result) if result else {}
|
||||
else:
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load metadata: %s", e)
|
||||
return {}
|
||||
|
||||
def _save_metadata(self, metadata_dict: dict):
|
||||
"""Save metadata file"""
|
||||
try:
|
||||
metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
|
||||
self._storage.save(self._metadata_file, metadata_content.encode("utf-8"))
|
||||
logger.debug("Metadata saved successfully")
|
||||
except Exception:
|
||||
logger.exception("Failed to save metadata")
|
||||
raise
|
||||
|
||||
def _calculate_checksum(self, data: bytes) -> str:
|
||||
"""Calculate file checksum"""
|
||||
import hashlib
|
||||
|
||||
return hashlib.md5(data).hexdigest()
|
||||
|
||||
def _check_permission(self, filename: str, operation: str) -> bool:
|
||||
"""Check file operation permission
|
||||
|
||||
Args:
|
||||
filename: File name
|
||||
operation: Operation type
|
||||
|
||||
Returns:
|
||||
True if permission granted, False otherwise
|
||||
"""
|
||||
# If no permission manager, allow by default
|
||||
if not self._permission_manager:
|
||||
return True
|
||||
|
||||
try:
|
||||
# Map operation type to permission
|
||||
operation_mapping = {
|
||||
"save": "save",
|
||||
"load": "load_once",
|
||||
"delete": "delete",
|
||||
"archive": "delete", # Archive requires delete permission
|
||||
"restore": "save", # Restore requires write permission
|
||||
"cleanup": "delete", # Cleanup requires delete permission
|
||||
"read": "load_once",
|
||||
"write": "save",
|
||||
}
|
||||
|
||||
mapped_operation = operation_mapping.get(operation, operation)
|
||||
|
||||
# Check permission
|
||||
result = self._permission_manager.validate_operation(mapped_operation, self._dataset_id)
|
||||
return bool(result)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Permission check failed for %s operation %s", filename, operation)
|
||||
# Safe default: deny access when permission check fails
|
||||
return False
|
||||
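A sketch of the lifecycle manager driven by a toy dict-backed storage object that exposes just the methods the manager calls (save, load_once, exists, delete, scan). No permission manager is attached, so all permission checks pass; the file name and tags are illustrative:

from extensions.storage.clickzetta_volume.file_lifecycle import FileLifecycleManager, FileMetadata


class DictStorage:
    """Just enough of the storage surface for FileLifecycleManager."""

    def __init__(self):
        self.files: dict[str, bytes] = {}

    def save(self, name, data):
        self.files[name] = data

    def load_once(self, name):
        return self.files[name]

    def exists(self, name):
        return name in self.files

    def delete(self, name):
        self.files.pop(name, None)

    def scan(self, path, files=True, directories=False):
        return list(self.files)


manager = FileLifecycleManager(DictStorage())
meta = manager.save_with_lifecycle("notes.txt", b"v1", tags={"source": "demo"})
print(meta.version)                                     # -> 1
meta = manager.save_with_lifecycle("notes.txt", b"v2")
print(meta.version)                                     # -> 2, v1 kept under .versions/
print(manager.get_storage_statistics()["total_files"])  # -> 1

# Metadata also survives a to_dict()/from_dict() round trip:
assert FileMetadata.from_dict(meta.to_dict()) == meta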
@@ -0,0 +1,649 @@
"""ClickZetta Volume permission management mechanism

This module provides Volume permission checking, validation and management features.
According to ClickZetta's permission model, different Volume types have different permission requirements.
"""

import logging
from enum import StrEnum

logger = logging.getLogger(__name__)


class VolumePermission(StrEnum):
    """Volume permission type enumeration"""

    READ = "SELECT"  # Corresponds to ClickZetta's SELECT permission
    WRITE = "INSERT,UPDATE,DELETE"  # Corresponds to ClickZetta's write permissions
    LIST = "SELECT"  # Listing files requires SELECT permission
    DELETE = "INSERT,UPDATE,DELETE"  # Deleting files requires write permissions
    USAGE = "USAGE"  # Basic permission required for External Volume


class VolumePermissionManager:
    """Volume permission manager"""

    def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: str | None = None):
        """Initialize permission manager

        Args:
            connection_or_config: ClickZetta connection object or configuration dictionary
            volume_type: Volume type (user|table|external)
            volume_name: Volume name (for external volume)
        """
        # Support two initialization methods: connection object or configuration dictionary
        if isinstance(connection_or_config, dict):
            # Create connection from configuration dictionary
            import clickzetta

            config = connection_or_config
            self._connection = clickzetta.connect(
                username=config.get("username"),
                password=config.get("password"),
                instance=config.get("instance"),
                service=config.get("service"),
                workspace=config.get("workspace"),
                vcluster=config.get("vcluster"),
                schema=config.get("schema") or config.get("database"),
            )
            self._volume_type = config.get("volume_type", volume_type)
            self._volume_name = config.get("volume_name", volume_name)
        else:
            # Use connection object directly
            self._connection = connection_or_config
            self._volume_type = volume_type
            self._volume_name = volume_name

        if not self._connection:
            raise ValueError("Valid connection or config is required")
        if not self._volume_type:
            raise ValueError("volume_type is required")

        self._permission_cache: dict[str, set[str]] = {}
        self._current_username = None  # Will get current username from connection
    def check_permission(self, operation: VolumePermission, dataset_id: str | None = None) -> bool:
        """Check if user has permission to perform specific operation

        Args:
            operation: Type of operation to perform
            dataset_id: Dataset ID (for table volume)

        Returns:
            True if user has permission, False otherwise
        """
        try:
            if self._volume_type == "user":
                return self._check_user_volume_permission(operation)
            elif self._volume_type == "table":
                return self._check_table_volume_permission(operation, dataset_id)
            elif self._volume_type == "external":
                return self._check_external_volume_permission(operation)
            else:
                logger.warning("Unknown volume type: %s", self._volume_type)
                return False

        except Exception:
            logger.exception("Permission check failed")
            return False

    def _check_user_volume_permission(self, operation: VolumePermission) -> bool:
        """Check User Volume permission

        User Volume permission rules:
        - User has full permissions on their own User Volume
        - As long as user can connect to ClickZetta, they have basic User Volume permissions by default
        - Focus more on connection authentication rather than complex permission checking
        """
        try:
            # Get current username
            current_user = self._get_current_username()

            # Check basic connection status
            with self._connection.cursor() as cursor:
                # Simple connection test, if query can be executed user has basic permissions
                cursor.execute("SELECT 1")
                result = cursor.fetchone()

                if result:
                    logger.debug(
                        "User Volume permission check for %s, operation %s: granted (basic connection verified)",
                        current_user,
                        operation.name,
                    )
                    return True
                else:
                    logger.warning(
                        "User Volume permission check failed: cannot verify basic connection for %s", current_user
                    )
                    return False

        except Exception:
            logger.exception("User Volume permission check failed")
            # For User Volume, if permission check fails, it might be a configuration issue,
            # provide friendlier error message
            logger.info("User Volume permission check failed, but permission checking is disabled in this version")
            return False
    def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: str | None) -> bool:
        """Check Table Volume permission

        Table Volume permission rules:
        - Table Volume permissions inherit from corresponding table permissions
        - SELECT permission -> can READ/LIST files
        - INSERT,UPDATE,DELETE permissions -> can WRITE/DELETE files
        """
        if not dataset_id:
            logger.warning("dataset_id is required for table volume permission check")
            return False

        table_name = f"dataset_{dataset_id}" if not dataset_id.startswith("dataset_") else dataset_id

        try:
            # Check table permissions
            permissions = self._get_table_permissions(table_name)
            required_permissions = set(operation.value.split(","))

            # Check if has all required permissions
            has_permission = required_permissions.issubset(permissions)

            logger.debug(
                "Table Volume permission check for %s, operation %s: required=%s, has=%s, granted=%s",
                table_name,
                operation.name,
                required_permissions,
                permissions,
                has_permission,
            )

            return has_permission

        except Exception:
            logger.exception("Table volume permission check failed for %s", table_name)
            return False

    def _check_external_volume_permission(self, operation: VolumePermission) -> bool:
        """Check External Volume permission

        External Volume permission rules:
        - Try to get permissions for External Volume
        - If permission check fails, perform fallback verification
        - For development environment, provide more lenient permission checking
        """
        if not self._volume_name:
            logger.warning("volume_name is required for external volume permission check")
            return False

        try:
            # Check External Volume permissions
            permissions = self._get_external_volume_permissions(self._volume_name)

            # External Volume permission mapping: determine required permissions based on operation type
            required_permissions = set()

            if operation in [VolumePermission.READ, VolumePermission.LIST]:
                required_permissions.add("read")
            elif operation in [VolumePermission.WRITE, VolumePermission.DELETE]:
                required_permissions.add("write")

            # Check if has all required permissions
            has_permission = required_permissions.issubset(permissions)

            logger.debug(
                "External Volume permission check for %s, operation %s: required=%s, has=%s, granted=%s",
                self._volume_name,
                operation.name,
                required_permissions,
                permissions,
                has_permission,
            )

            # If permission check fails, try fallback verification
            if not has_permission:
                logger.info("Direct permission check failed for %s, trying fallback verification", self._volume_name)

                # Fallback verification: try listing Volume to verify basic access permissions
                try:
                    with self._connection.cursor() as cursor:
                        cursor.execute("SHOW VOLUMES")
                        volumes = cursor.fetchall()
                        for volume in volumes:
                            if len(volume) > 0 and volume[0] == self._volume_name:
                                logger.info("Fallback verification successful for %s", self._volume_name)
                                return True
                except Exception as fallback_e:
                    logger.warning("Fallback verification failed for %s: %s", self._volume_name, fallback_e)

            return has_permission

        except Exception:
            logger.exception("External volume permission check failed for %s", self._volume_name)
            logger.info("External Volume permission check failed, but permission checking is disabled in this version")
            return False
    def _get_table_permissions(self, table_name: str) -> set[str]:
        """Get user permissions for specified table

        Args:
            table_name: Table name

        Returns:
            Set of user permissions for this table
        """
        cache_key = f"table:{table_name}"

        if cache_key in self._permission_cache:
            return self._permission_cache[cache_key]

        permissions = set()

        try:
            with self._connection.cursor() as cursor:
                # Use correct ClickZetta syntax to check current user permissions
                cursor.execute("SHOW GRANTS")
                grants = cursor.fetchall()

                # Parse permission results, find permissions for this table
                for grant in grants:
                    if len(grant) >= 3:  # Typical format: (privilege, object_type, object_name, ...)
                        privilege = grant[0].upper()
                        object_type = grant[1].upper() if len(grant) > 1 else ""
                        object_name = grant[2] if len(grant) > 2 else ""

                        # Check if it's permission for this table
                        if (
                            object_type == "TABLE"
                            and object_name == table_name
                            or object_type == "SCHEMA"
                            and object_name in table_name
                        ):
                            if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
                                if privilege == "ALL":
                                    permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
                                else:
                                    permissions.add(privilege)

                # If no explicit permissions found, try executing a simple query to verify permissions
                if not permissions:
                    try:
                        cursor.execute(f"SELECT COUNT(*) FROM {table_name} LIMIT 1")
                        permissions.add("SELECT")
                    except Exception:
                        logger.debug("Cannot query table %s, no SELECT permission", table_name)

        except Exception as e:
            logger.warning("Could not check table permissions for %s: %s", table_name, e)
            # Safe default: deny access when permission check fails
            pass

        # Cache permission information
        self._permission_cache[cache_key] = permissions
        return permissions

    def _get_current_username(self) -> str:
        """Get current username"""
        if self._current_username:
            return self._current_username

        try:
            with self._connection.cursor() as cursor:
                cursor.execute("SELECT CURRENT_USER()")
                result = cursor.fetchone()
                if result:
                    self._current_username = result[0]
                    return str(self._current_username)
        except Exception:
            logger.exception("Failed to get current username")

        return "unknown"
    def _get_user_permissions(self, username: str) -> set[str]:
        """Get user's basic permission set"""
        cache_key = f"user_permissions:{username}"

        if cache_key in self._permission_cache:
            return self._permission_cache[cache_key]

        permissions = set()

        try:
            with self._connection.cursor() as cursor:
                # Use correct ClickZetta syntax to check current user permissions
                cursor.execute("SHOW GRANTS")
                grants = cursor.fetchall()

                # Parse permission results, find user's basic permissions
                for grant in grants:
                    if len(grant) >= 3:  # Typical format: (privilege, object_type, object_name, ...)
                        privilege = grant[0].upper()
                        _ = grant[1].upper() if len(grant) > 1 else ""

                        # Collect all relevant permissions
                        if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
                            if privilege == "ALL":
                                permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
                            else:
                                permissions.add(privilege)

        except Exception as e:
            logger.warning("Could not check user permissions for %s: %s", username, e)
            # Safe default: deny access when permission check fails
            pass

        # Cache permission information
        self._permission_cache[cache_key] = permissions
        return permissions
    def _get_external_volume_permissions(self, volume_name: str) -> set[str]:
        """Get user permissions for specified External Volume

        Args:
            volume_name: External Volume name

        Returns:
            Set of user permissions for this Volume
        """
        cache_key = f"external_volume:{volume_name}"

        if cache_key in self._permission_cache:
            return self._permission_cache[cache_key]

        permissions = set()

        try:
            with self._connection.cursor() as cursor:
                # Use correct ClickZetta syntax to check Volume permissions
                logger.info("Checking permissions for volume: %s", volume_name)
                cursor.execute(f"SHOW GRANTS ON VOLUME {volume_name}")
                grants = cursor.fetchall()

                logger.info("Raw grants result for %s: %s", volume_name, grants)

                # Parse permission results
                # Format: (granted_type, privilege, conditions, granted_on, object_name, granted_to,
                # grantee_name, grantor_name, grant_option, granted_time)
                for grant in grants:
                    logger.info("Processing grant: %s", grant)
                    if len(grant) >= 5:
                        granted_type = grant[0]
                        privilege = grant[1].upper()
                        granted_on = grant[3]
                        object_name = grant[4]

                        logger.info(
                            "Grant details - type: %s, privilege: %s, granted_on: %s, object_name: %s",
                            granted_type,
                            privilege,
                            granted_on,
                            object_name,
                        )

                        # Check if it's permission for this Volume or hierarchical permission
                        if (
                            granted_type == "PRIVILEGE" and granted_on == "VOLUME" and object_name.endswith(volume_name)
                        ) or (granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME"):
                            logger.info("Matching grant found for %s", volume_name)

                            if "READ" in privilege:
                                permissions.add("read")
                                logger.info("Added READ permission for %s", volume_name)
                            if "WRITE" in privilege:
                                permissions.add("write")
                                logger.info("Added WRITE permission for %s", volume_name)
                            if "ALTER" in privilege:
                                permissions.add("alter")
                                logger.info("Added ALTER permission for %s", volume_name)
                            if privilege == "ALL":
                                permissions.update(["read", "write", "alter"])
                                logger.info("Added ALL permissions for %s", volume_name)

                logger.info("Final permissions for %s: %s", volume_name, permissions)

                # If no explicit permissions found, try viewing Volume list to verify basic permissions
                if not permissions:
                    try:
                        cursor.execute("SHOW VOLUMES")
                        volumes = cursor.fetchall()
                        for volume in volumes:
                            if len(volume) > 0 and volume[0] == volume_name:
                                permissions.add("read")  # At least has read permission
                                logger.debug("Volume %s found in SHOW VOLUMES, assuming read permission", volume_name)
                                break
                    except Exception:
                        logger.debug("Cannot access volume %s, no basic permission", volume_name)

        except Exception as e:
            logger.warning("Could not check external volume permissions for %s: %s", volume_name, e)
            # When permission check fails, try basic Volume access verification
            try:
                with self._connection.cursor() as cursor:
                    cursor.execute("SHOW VOLUMES")
                    volumes = cursor.fetchall()
                    for volume in volumes:
                        if len(volume) > 0 and volume[0] == volume_name:
                            logger.info("Basic volume access verified for %s", volume_name)
                            permissions.add("read")
                            permissions.add("write")  # Assume has write permission
                            break
            except Exception as basic_e:
                logger.warning("Basic volume access check failed for %s: %s", volume_name, basic_e)
                # Last fallback: assume basic permissions
                permissions.add("read")

        # Cache permission information
        self._permission_cache[cache_key] = permissions
        return permissions
    def clear_permission_cache(self):
        """Clear permission cache"""
        self._permission_cache.clear()
        logger.debug("Permission cache cleared")

    @property
    def volume_type(self) -> str | None:
        """Get the volume type."""
        return self._volume_type

    def get_permission_summary(self, dataset_id: str | None = None) -> dict[str, bool]:
        """Get permission summary

        Args:
            dataset_id: Dataset ID (for table volume)

        Returns:
            Permission summary dictionary
        """
        summary = {}

        for operation in VolumePermission:
            summary[operation.name.lower()] = self.check_permission(operation, dataset_id)

        return summary
    def check_inherited_permission(self, file_path: str, operation: VolumePermission) -> bool:
        """Check permission inheritance for file path

        Args:
            file_path: File path
            operation: Operation to perform

        Returns:
            True if user has permission, False otherwise
        """
        try:
            # Parse file path
            path_parts = file_path.strip("/").split("/")

            if not path_parts:
                logger.warning("Invalid file path for permission inheritance check")
                return False

            # For Table Volume, first layer is dataset_id
            if self._volume_type == "table":
                if len(path_parts) < 1:
                    return False

                dataset_id = path_parts[0]

                # Check permissions for dataset
                has_dataset_permission = self.check_permission(operation, dataset_id)

                if not has_dataset_permission:
                    logger.debug("Permission denied for dataset %s", dataset_id)
                    return False

                # Check path traversal attack
                if self._contains_path_traversal(file_path):
                    logger.warning("Path traversal attack detected: %s", file_path)
                    return False

                # Check if accessing sensitive directory
                if self._is_sensitive_path(file_path):
                    logger.warning("Access to sensitive path denied: %s", file_path)
                    return False

                logger.debug("Permission inherited for path %s", file_path)
                return True

            elif self._volume_type == "user":
                # User Volume permission inheritance
                current_user = self._get_current_username()

                # Check if attempting to access other user's directory
                if len(path_parts) > 1 and path_parts[0] != current_user:
                    logger.warning("User %s attempted to access %s's directory", current_user, path_parts[0])
                    return False

                # Check basic permissions
                return self.check_permission(operation)

            elif self._volume_type == "external":
                # External Volume permission inheritance
                # Check permissions for External Volume
                return self.check_permission(operation)

            else:
                logger.warning("Unknown volume type for permission inheritance: %s", self._volume_type)
                return False

        except Exception:
            logger.exception("Permission inheritance check failed")
            return False
    def _contains_path_traversal(self, file_path: str) -> bool:
        """Check if path contains path traversal attack"""
        # Check common path traversal patterns
        traversal_patterns = [
            "../",
            "..\\",
            "..%2f",
            "..%2F",
            "..%5c",
            "..%5C",
            "%2e%2e%2f",
            "%2e%2e%5c",
            "....//",
            "....\\\\",
        ]

        file_path_lower = file_path.lower()

        for pattern in traversal_patterns:
            if pattern in file_path_lower:
                return True

        # Check absolute path
        if file_path.startswith("/") or file_path.startswith("\\"):
            return True

        # Check Windows drive path
        if len(file_path) >= 2 and file_path[1] == ":":
            return True

        return False

    def _is_sensitive_path(self, file_path: str) -> bool:
        """Check if path is sensitive path"""
        sensitive_patterns = [
            "passwd",
            "shadow",
            "hosts",
            "config",
            "secrets",
            "private",
            "key",
            "certificate",
            "cert",
            "ssl",
            "database",
            "backup",
            "dump",
            "log",
            "tmp",
        ]

        file_path_lower = file_path.lower()

        return any(pattern in file_path_lower for pattern in sensitive_patterns)
    def validate_operation(self, operation: str, dataset_id: str | None = None) -> bool:
        """Validate operation permission

        Args:
            operation: Operation name (save|load|exists|delete|scan)
            dataset_id: Dataset ID

        Returns:
            True if operation is allowed, False otherwise
        """
        operation_mapping = {
            "save": VolumePermission.WRITE,
            "load": VolumePermission.READ,
            "load_once": VolumePermission.READ,
            "load_stream": VolumePermission.READ,
            "download": VolumePermission.READ,
            "exists": VolumePermission.READ,
            "delete": VolumePermission.DELETE,
            "scan": VolumePermission.LIST,
        }

        if operation not in operation_mapping:
            logger.warning("Unknown operation: %s", operation)
            return False

        volume_permission = operation_mapping[operation]
        return self.check_permission(volume_permission, dataset_id)


class VolumePermissionError(Exception):
    """Volume permission error exception"""

    def __init__(self, message: str, operation: str, volume_type: str, dataset_id: str | None = None):
        self.operation = operation
        self.volume_type = volume_type
        self.dataset_id = dataset_id
        super().__init__(message)


def check_volume_permission(permission_manager: VolumePermissionManager, operation: str, dataset_id: str | None = None):
    """Permission check helper function

    Args:
        permission_manager: Permission manager
        operation: Operation name
        dataset_id: Dataset ID

    Raises:
        VolumePermissionError: If no permission
    """
    if not permission_manager.validate_operation(operation, dataset_id):
        error_message = f"Permission denied for operation '{operation}' on {permission_manager.volume_type} volume"
        if dataset_id:
            error_message += f" (dataset: {dataset_id})"

        raise VolumePermissionError(
            error_message,
            operation=operation,
            volume_type=permission_manager.volume_type or "unknown",
            dataset_id=dataset_id,
        )
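To make the call pattern above concrete, here is a minimal usage sketch (not part of the commit). It assumes `conn` is an already-established clickzetta connection and that the module's names are in scope; the dataset id is a made-up example value.

# Illustrative usage of VolumePermissionManager / check_volume_permission (assumed caller code).
manager = VolumePermissionManager(conn, volume_type="table")

# Boolean check before writing into a table volume.
if manager.check_permission(VolumePermission.WRITE, dataset_id="123"):
    ...  # proceed with the save

# Or fail fast: validate_operation() maps storage operations ("save", "load", ...)
# to VolumePermission values, and the helper raises on denial.
try:
    check_volume_permission(manager, "save", dataset_id="123")
except VolumePermissionError as e:
    print(f"denied: {e} (operation={e.operation}, volume_type={e.volume_type})")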
66
dify/api/extensions/storage/google_cloud_storage.py
Normal file
@@ -0,0 +1,66 @@
import base64
import io
import json
from collections.abc import Generator

from google.cloud import storage as google_cloud_storage  # type: ignore

from configs import dify_config
from extensions.storage.base_storage import BaseStorage


class GoogleCloudStorage(BaseStorage):
    """Implementation for Google Cloud storage."""

    def __init__(self):
        super().__init__()

        self.bucket_name = dify_config.GOOGLE_STORAGE_BUCKET_NAME
        service_account_json_str = dify_config.GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64
        # if service_account_json_str is empty, use Application Default Credentials
        if service_account_json_str:
            service_account_json = base64.b64decode(service_account_json_str).decode("utf-8")
            # convert str to object
            service_account_obj = json.loads(service_account_json)
            self.client = google_cloud_storage.Client.from_service_account_info(service_account_obj)
        else:
            self.client = google_cloud_storage.Client()

    def save(self, filename, data):
        bucket = self.client.get_bucket(self.bucket_name)
        blob = bucket.blob(filename)
        with io.BytesIO(data) as stream:
            blob.upload_from_file(stream)

    def load_once(self, filename: str) -> bytes:
        bucket = self.client.get_bucket(self.bucket_name)
        blob = bucket.get_blob(filename)
        if blob is None:
            raise FileNotFoundError("File not found")
        data: bytes = blob.download_as_bytes()
        return data

    def load_stream(self, filename: str) -> Generator:
        bucket = self.client.get_bucket(self.bucket_name)
        blob = bucket.get_blob(filename)
        if blob is None:
            raise FileNotFoundError("File not found")
        with blob.open(mode="rb") as blob_stream:
            while chunk := blob_stream.read(4096):
                yield chunk

    def download(self, filename, target_filepath):
        bucket = self.client.get_bucket(self.bucket_name)
        blob = bucket.get_blob(filename)
        if blob is None:
            raise FileNotFoundError("File not found")
        blob.download_to_filename(target_filepath)

    def exists(self, filename):
        bucket = self.client.get_bucket(self.bucket_name)
        blob = bucket.blob(filename)
        return blob.exists()

    def delete(self, filename):
        bucket = self.client.get_bucket(self.bucket_name)
        bucket.delete_blob(filename)
51
dify/api/extensions/storage/huawei_obs_storage.py
Normal file
@@ -0,0 +1,51 @@
from collections.abc import Generator

from obs import ObsClient

from configs import dify_config
from extensions.storage.base_storage import BaseStorage


class HuaweiObsStorage(BaseStorage):
    """Implementation for Huawei OBS storage."""

    def __init__(self):
        super().__init__()

        self.bucket_name = dify_config.HUAWEI_OBS_BUCKET_NAME
        self.client = ObsClient(
            access_key_id=dify_config.HUAWEI_OBS_ACCESS_KEY,
            secret_access_key=dify_config.HUAWEI_OBS_SECRET_KEY,
            server=dify_config.HUAWEI_OBS_SERVER,
        )

    def save(self, filename, data):
        self.client.putObject(bucketName=self.bucket_name, objectKey=filename, content=data)

    def load_once(self, filename: str) -> bytes:
        data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read()
        return data

    def load_stream(self, filename: str) -> Generator:
        response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response
        while chunk := response.read(4096):
            yield chunk

    def download(self, filename, target_filepath):
        self.client.getObject(bucketName=self.bucket_name, objectKey=filename, downloadPath=target_filepath)

    def exists(self, filename):
        res = self._get_meta(filename)
        if res is None:
            return False
        return True

    def delete(self, filename):
        self.client.deleteObject(bucketName=self.bucket_name, objectKey=filename)

    def _get_meta(self, filename):
        res = self.client.getObjectMetadata(bucketName=self.bucket_name, objectKey=filename)
        if res and res.status and res.status < 300:
            return res
        else:
            return None
101
dify/api/extensions/storage/opendal_storage.py
Normal file
@@ -0,0 +1,101 @@
import logging
import os
from collections.abc import Generator
from pathlib import Path

import opendal
from dotenv import dotenv_values
from opendal import Operator

from extensions.storage.base_storage import BaseStorage

logger = logging.getLogger(__name__)


def _get_opendal_kwargs(*, scheme: str, env_file_path: str = ".env", prefix: str = "OPENDAL_"):
    kwargs = {}
    config_prefix = prefix + scheme.upper() + "_"
    for key, value in os.environ.items():
        if key.startswith(config_prefix):
            kwargs[key[len(config_prefix) :].lower()] = value

    file_env_vars: dict = dotenv_values(env_file_path) or {}
    for key, value in file_env_vars.items():
        if key.startswith(config_prefix) and key[len(config_prefix) :].lower() not in kwargs and value:
            kwargs[key[len(config_prefix) :].lower()] = value

    return kwargs


class OpenDALStorage(BaseStorage):
    def __init__(self, scheme: str, **kwargs):
        kwargs = kwargs or _get_opendal_kwargs(scheme=scheme)

        if scheme == "fs":
            root = kwargs.get("root", "storage")
            Path(root).mkdir(parents=True, exist_ok=True)

        retry_layer = opendal.layers.RetryLayer(max_times=3, factor=2.0, jitter=True)
        self.op = Operator(scheme=scheme, **kwargs).layer(retry_layer)
        logger.debug("opendal operator created with scheme %s", scheme)
        logger.debug("added retry layer to opendal operator")

    def save(self, filename: str, data: bytes):
        self.op.write(path=filename, bs=data)
        logger.debug("file %s saved", filename)

    def load_once(self, filename: str) -> bytes:
        if not self.exists(filename):
            raise FileNotFoundError("File not found")

        content: bytes = self.op.read(path=filename)
        logger.debug("file %s loaded", filename)
        return content

    def load_stream(self, filename: str) -> Generator:
        if not self.exists(filename):
            raise FileNotFoundError("File not found")

        batch_size = 4096
        with self.op.open(
            path=filename,
            mode="rb",
            chunck=batch_size,
        ) as file:
            while chunk := file.read(batch_size):
                yield chunk
        logger.debug("file %s loaded as stream", filename)

    def download(self, filename: str, target_filepath: str):
        if not self.exists(filename):
            raise FileNotFoundError("File not found")

        Path(target_filepath).write_bytes(self.op.read(path=filename))
        logger.debug("file %s downloaded to %s", filename, target_filepath)

    def exists(self, filename: str) -> bool:
        return self.op.exists(path=filename)

    def delete(self, filename: str):
        if self.exists(filename):
            self.op.delete(path=filename)
            logger.debug("file %s deleted", filename)
            return
        logger.debug("file %s not found, skip delete", filename)

    def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
        if not self.exists(path):
            raise FileNotFoundError("Path not found")

        all_files = self.op.list(path=path)
        if files and directories:
            logger.debug("files and directories on %s scanned", path)
            return [f.path for f in all_files]
        if files:
            logger.debug("files on %s scanned", path)
            return [f.path for f in all_files if not f.path.endswith("/")]
        elif directories:
            logger.debug("directories on %s scanned", path)
            return [f.path for f in all_files if f.path.endswith("/")]
        else:
            raise ValueError("At least one of files or directories must be True")
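A note on configuration: `_get_opendal_kwargs` builds the Operator kwargs from `OPENDAL_<SCHEME>_*` variables, first from the process environment and then from the `.env` file, with the prefix stripped and the remainder lowercased. The snippet below is an illustrative sketch of that mapping, not part of the commit; the variable names and values are invented examples.

import os

# OPENDAL_FS_ROOT=storage  ->  kwargs["root"] = "storage" for scheme "fs"
os.environ["OPENDAL_FS_ROOT"] = "storage"

fs_storage = OpenDALStorage(scheme="fs")   # picks up OPENDAL_FS_* settings
fs_storage.save("hello.txt", b"hello")
assert fs_storage.load_once("hello.txt") == b"hello"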
59
dify/api/extensions/storage/oracle_oci_storage.py
Normal file
@@ -0,0 +1,59 @@
from collections.abc import Generator

import boto3
from botocore.exceptions import ClientError

from configs import dify_config
from extensions.storage.base_storage import BaseStorage


class OracleOCIStorage(BaseStorage):
    """Implementation for Oracle OCI storage."""

    def __init__(self):
        super().__init__()

        self.bucket_name = dify_config.OCI_BUCKET_NAME
        self.client = boto3.client(
            "s3",
            aws_secret_access_key=dify_config.OCI_SECRET_KEY,
            aws_access_key_id=dify_config.OCI_ACCESS_KEY,
            endpoint_url=dify_config.OCI_ENDPOINT,
            region_name=dify_config.OCI_REGION,
        )

    def save(self, filename, data):
        self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)

    def load_once(self, filename: str) -> bytes:
        try:
            data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
        except ClientError as ex:
            if ex.response.get("Error", {}).get("Code") == "NoSuchKey":
                raise FileNotFoundError("File not found")
            else:
                raise
        return data

    def load_stream(self, filename: str) -> Generator:
        try:
            response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
            yield from response["Body"].iter_chunks()
        except ClientError as ex:
            if ex.response.get("Error", {}).get("Code") == "NoSuchKey":
                raise FileNotFoundError("File not found")
            else:
                raise

    def download(self, filename, target_filepath):
        self.client.download_file(self.bucket_name, filename, target_filepath)

    def exists(self, filename):
        try:
            self.client.head_object(Bucket=self.bucket_name, Key=filename)
            return True
        except:
            return False

    def delete(self, filename):
        self.client.delete_object(Bucket=self.bucket_name, Key=filename)
17
dify/api/extensions/storage/storage_type.py
Normal file
@@ -0,0 +1,17 @@
from enum import StrEnum


class StorageType(StrEnum):
    ALIYUN_OSS = "aliyun-oss"
    AZURE_BLOB = "azure-blob"
    BAIDU_OBS = "baidu-obs"
    CLICKZETTA_VOLUME = "clickzetta-volume"
    GOOGLE_STORAGE = "google-storage"
    HUAWEI_OBS = "huawei-obs"
    LOCAL = "local"
    OCI_STORAGE = "oci-storage"
    OPENDAL = "opendal"
    S3 = "s3"
    TENCENT_COS = "tencent-cos"
    VOLCENGINE_TOS = "volcengine-tos"
    SUPABASE = "supabase"
59
dify/api/extensions/storage/supabase_storage.py
Normal file
@@ -0,0 +1,59 @@
import io
from collections.abc import Generator
from pathlib import Path

from supabase import Client

from configs import dify_config
from extensions.storage.base_storage import BaseStorage


class SupabaseStorage(BaseStorage):
    """Implementation for Supabase storage."""

    def __init__(self):
        super().__init__()
        if dify_config.SUPABASE_URL is None:
            raise ValueError("SUPABASE_URL is not set")
        if dify_config.SUPABASE_API_KEY is None:
            raise ValueError("SUPABASE_API_KEY is not set")
        if dify_config.SUPABASE_BUCKET_NAME is None:
            raise ValueError("SUPABASE_BUCKET_NAME is not set")

        self.bucket_name = dify_config.SUPABASE_BUCKET_NAME
        self.client = Client(supabase_url=dify_config.SUPABASE_URL, supabase_key=dify_config.SUPABASE_API_KEY)
        self.create_bucket(id=dify_config.SUPABASE_BUCKET_NAME, bucket_name=dify_config.SUPABASE_BUCKET_NAME)

    def create_bucket(self, id, bucket_name):
        if not self.bucket_exists():
            self.client.storage.create_bucket(id=id, name=bucket_name)

    def save(self, filename, data):
        self.client.storage.from_(self.bucket_name).upload(filename, data)

    def load_once(self, filename: str) -> bytes:
        content: bytes = self.client.storage.from_(self.bucket_name).download(filename)
        return content

    def load_stream(self, filename: str) -> Generator:
        result = self.client.storage.from_(self.bucket_name).download(filename)
        byte_stream = io.BytesIO(result)
        while chunk := byte_stream.read(4096):  # Read in chunks of 4KB
            yield chunk

    def download(self, filename, target_filepath):
        result = self.client.storage.from_(self.bucket_name).download(filename)
        Path(target_filepath).write_bytes(result)

    def exists(self, filename):
        result = self.client.storage.from_(self.bucket_name).list(path=filename)
        if len(result) > 0:
            return True
        return False

    def delete(self, filename):
        self.client.storage.from_(self.bucket_name).remove([filename])

    def bucket_exists(self):
        buckets = self.client.storage.list_buckets()
        return any(bucket.name == self.bucket_name for bucket in buckets)
43
dify/api/extensions/storage/tencent_cos_storage.py
Normal file
@@ -0,0 +1,43 @@
from collections.abc import Generator

from qcloud_cos import CosConfig, CosS3Client

from configs import dify_config
from extensions.storage.base_storage import BaseStorage


class TencentCosStorage(BaseStorage):
    """Implementation for Tencent Cloud COS storage."""

    def __init__(self):
        super().__init__()

        self.bucket_name = dify_config.TENCENT_COS_BUCKET_NAME
        config = CosConfig(
            Region=dify_config.TENCENT_COS_REGION,
            SecretId=dify_config.TENCENT_COS_SECRET_ID,
            SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
            Scheme=dify_config.TENCENT_COS_SCHEME,
        )
        self.client = CosS3Client(config)

    def save(self, filename, data):
        self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename)

    def load_once(self, filename: str) -> bytes:
        data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read()
        return data

    def load_stream(self, filename: str) -> Generator:
        response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
        yield from response["Body"].get_stream(chunk_size=4096)

    def download(self, filename, target_filepath):
        response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
        response["Body"].get_stream_to_file(target_filepath)

    def exists(self, filename):
        return self.client.object_exists(Bucket=self.bucket_name, Key=filename)

    def delete(self, filename):
        self.client.delete_object(Bucket=self.bucket_name, Key=filename)
66
dify/api/extensions/storage/volcengine_tos_storage.py
Normal file
@@ -0,0 +1,66 @@
from collections.abc import Generator

import tos

from configs import dify_config
from extensions.storage.base_storage import BaseStorage


class VolcengineTosStorage(BaseStorage):
    """Implementation for Volcengine TOS storage."""

    def __init__(self):
        super().__init__()
        if not dify_config.VOLCENGINE_TOS_ACCESS_KEY:
            raise ValueError("VOLCENGINE_TOS_ACCESS_KEY is not set")
        if not dify_config.VOLCENGINE_TOS_SECRET_KEY:
            raise ValueError("VOLCENGINE_TOS_SECRET_KEY is not set")
        if not dify_config.VOLCENGINE_TOS_ENDPOINT:
            raise ValueError("VOLCENGINE_TOS_ENDPOINT is not set")
        if not dify_config.VOLCENGINE_TOS_REGION:
            raise ValueError("VOLCENGINE_TOS_REGION is not set")
        self.bucket_name = dify_config.VOLCENGINE_TOS_BUCKET_NAME
        self.client = tos.TosClientV2(
            ak=dify_config.VOLCENGINE_TOS_ACCESS_KEY,
            sk=dify_config.VOLCENGINE_TOS_SECRET_KEY,
            endpoint=dify_config.VOLCENGINE_TOS_ENDPOINT,
            region=dify_config.VOLCENGINE_TOS_REGION,
        )

    def save(self, filename, data):
        if not self.bucket_name:
            raise ValueError("VOLCENGINE_TOS_BUCKET_NAME is not set")
        self.client.put_object(bucket=self.bucket_name, key=filename, content=data)

    def load_once(self, filename: str) -> bytes:
        if not self.bucket_name:
            raise FileNotFoundError("VOLCENGINE_TOS_BUCKET_NAME is not set")
        data = self.client.get_object(bucket=self.bucket_name, key=filename).read()
        if not isinstance(data, bytes):
            raise TypeError(f"Expected bytes, got {type(data).__name__}")
        return data

    def load_stream(self, filename: str) -> Generator:
        if not self.bucket_name:
            raise FileNotFoundError("VOLCENGINE_TOS_BUCKET_NAME is not set")
        response = self.client.get_object(bucket=self.bucket_name, key=filename)
        while chunk := response.read(4096):
            yield chunk

    def download(self, filename, target_filepath):
        if not self.bucket_name:
            raise ValueError("VOLCENGINE_TOS_BUCKET_NAME is not set")
        self.client.get_object_to_file(bucket=self.bucket_name, key=filename, file_path=target_filepath)

    def exists(self, filename):
        if not self.bucket_name:
            return False
        res = self.client.head_object(bucket=self.bucket_name, key=filename)
        if res.status_code != 200:
            return False
        return True

    def delete(self, filename):
        if not self.bucket_name:
            return
        self.client.delete_object(bucket=self.bucket_name, key=filename)
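All of the storage classes in this commit implement the same BaseStorage contract (save, load_once, load_stream, download, exists, delete), so callers can swap backends without changing call sites. The sketch below is illustrative only; the chosen backend and file names are assumptions, and the relevant bucket and credential settings are presumed to already be present in dify_config.

storage = TencentCosStorage()  # any of the classes above exposes the same interface

storage.save("docs/report.txt", b"quarterly numbers")
if storage.exists("docs/report.txt"):
    data = storage.load_once("docs/report.txt")         # whole object as bytes
    for chunk in storage.load_stream("docs/report.txt"):
        pass                                             # streamed 4 KB chunks
    storage.download("docs/report.txt", "/tmp/report.txt")
storage.delete("docs/report.txt")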