2025-12-01 17:21:38 +08:00
parent 32fee2b8ab
commit fab8c13cb3
7511 changed files with 996300 additions and 0 deletions


@@ -0,0 +1,56 @@
import posixpath
from collections.abc import Generator
import oss2 as aliyun_s3
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
class AliyunOssStorage(BaseStorage):
"""Implementation for Aliyun OSS storage."""
def __init__(self):
super().__init__()
self.bucket_name = dify_config.ALIYUN_OSS_BUCKET_NAME
self.folder = dify_config.ALIYUN_OSS_PATH
oss_auth_method = aliyun_s3.Auth
region = None
if dify_config.ALIYUN_OSS_AUTH_VERSION == "v4":
oss_auth_method = aliyun_s3.AuthV4
region = dify_config.ALIYUN_OSS_REGION
oss_auth = oss_auth_method(dify_config.ALIYUN_OSS_ACCESS_KEY, dify_config.ALIYUN_OSS_SECRET_KEY)
self.client = aliyun_s3.Bucket(
oss_auth,
dify_config.ALIYUN_OSS_ENDPOINT,
self.bucket_name,
connect_timeout=30,
region=region,
)
def save(self, filename, data):
self.client.put_object(self.__wrapper_folder_filename(filename), data)
def load_once(self, filename: str) -> bytes:
obj = self.client.get_object(self.__wrapper_folder_filename(filename))
data = obj.read()
if not isinstance(data, bytes):
return b""
return data
def load_stream(self, filename: str) -> Generator:
obj = self.client.get_object(self.__wrapper_folder_filename(filename))
while chunk := obj.read(4096):
yield chunk
def download(self, filename: str, target_filepath):
self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath)
def exists(self, filename: str):
return self.client.object_exists(self.__wrapper_folder_filename(filename))
def delete(self, filename: str):
self.client.delete_object(self.__wrapper_folder_filename(filename))
def __wrapper_folder_filename(self, filename: str) -> str:
return posixpath.join(self.folder, filename) if self.folder else filename


@@ -0,0 +1,87 @@
import logging
from collections.abc import Generator
import boto3
from botocore.client import Config
from botocore.exceptions import ClientError
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
logger = logging.getLogger(__name__)
class AwsS3Storage(BaseStorage):
"""Implementation for Amazon Web Services S3 storage."""
def __init__(self):
super().__init__()
self.bucket_name = dify_config.S3_BUCKET_NAME
if dify_config.S3_USE_AWS_MANAGED_IAM:
logger.info("Using AWS managed IAM role for S3")
session = boto3.Session()
region_name = dify_config.S3_REGION
self.client = session.client(service_name="s3", region_name=region_name)
else:
logger.info("Using ak and sk for S3")
self.client = boto3.client(
"s3",
aws_secret_access_key=dify_config.S3_SECRET_KEY,
aws_access_key_id=dify_config.S3_ACCESS_KEY,
endpoint_url=dify_config.S3_ENDPOINT,
region_name=dify_config.S3_REGION,
config=Config(s3={"addressing_style": dify_config.S3_ADDRESS_STYLE}),
)
        # ensure the bucket exists (create it if missing)
try:
self.client.head_bucket(Bucket=self.bucket_name)
except ClientError as e:
            # if the bucket does not exist, create it
if e.response.get("Error", {}).get("Code") == "404":
self.client.create_bucket(Bucket=self.bucket_name)
            # if the bucket is not accessible, skip; it may exist but this client lacks access to it
elif e.response.get("Error", {}).get("Code") == "403":
pass
else:
# other error, raise exception
raise
def save(self, filename, data):
self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
def load_once(self, filename: str) -> bytes:
try:
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
except ClientError as ex:
if ex.response.get("Error", {}).get("Code") == "NoSuchKey":
raise FileNotFoundError("File not found")
else:
raise
return data
def load_stream(self, filename: str) -> Generator:
try:
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
yield from response["Body"].iter_chunks()
except ClientError as ex:
if ex.response.get("Error", {}).get("Code") == "NoSuchKey":
raise FileNotFoundError("file not found")
elif "reached max retries" in str(ex):
raise ValueError("please do not request the same file too frequently")
else:
raise
def download(self, filename, target_filepath):
self.client.download_file(self.bucket_name, filename, target_filepath)
def exists(self, filename):
try:
self.client.head_object(Bucket=self.bucket_name, Key=filename)
return True
        except ClientError:
return False
def delete(self, filename):
self.client.delete_object(Bucket=self.bucket_name, Key=filename)


@@ -0,0 +1,104 @@
from collections.abc import Generator
from datetime import timedelta
from azure.identity import ChainedTokenCredential, DefaultAzureCredential
from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas
from configs import dify_config
from extensions.ext_redis import redis_client
from extensions.storage.base_storage import BaseStorage
from libs.datetime_utils import naive_utc_now
class AzureBlobStorage(BaseStorage):
"""Implementation for Azure Blob storage."""
def __init__(self):
super().__init__()
self.bucket_name = dify_config.AZURE_BLOB_CONTAINER_NAME
self.account_url = dify_config.AZURE_BLOB_ACCOUNT_URL
self.account_name = dify_config.AZURE_BLOB_ACCOUNT_NAME
self.account_key = dify_config.AZURE_BLOB_ACCOUNT_KEY
self.credential: ChainedTokenCredential | None = None
if self.account_key == "managedidentity":
self.credential = DefaultAzureCredential()
else:
self.credential = None
def save(self, filename, data):
if not self.bucket_name:
return
client = self._sync_client()
blob_container = client.get_container_client(container=self.bucket_name)
blob_container.upload_blob(filename, data)
def load_once(self, filename: str) -> bytes:
if not self.bucket_name:
raise FileNotFoundError("Azure bucket name is not configured.")
client = self._sync_client()
blob = client.get_container_client(container=self.bucket_name)
blob = blob.get_blob_client(blob=filename)
data = blob.download_blob().readall()
if not isinstance(data, bytes):
raise TypeError(f"Expected bytes from blob.readall(), got {type(data).__name__}")
return data
def load_stream(self, filename: str) -> Generator:
if not self.bucket_name:
raise FileNotFoundError("Azure bucket name is not configured.")
client = self._sync_client()
blob = client.get_blob_client(container=self.bucket_name, blob=filename)
blob_data = blob.download_blob()
yield from blob_data.chunks()
def download(self, filename, target_filepath):
if not self.bucket_name:
return
client = self._sync_client()
blob = client.get_blob_client(container=self.bucket_name, blob=filename)
with open(target_filepath, "wb") as my_blob:
blob_data = blob.download_blob()
blob_data.readinto(my_blob)
def exists(self, filename):
if not self.bucket_name:
return False
client = self._sync_client()
blob = client.get_blob_client(container=self.bucket_name, blob=filename)
return blob.exists()
def delete(self, filename):
if not self.bucket_name:
return
client = self._sync_client()
blob_container = client.get_container_client(container=self.bucket_name)
blob_container.delete_blob(filename)
def _sync_client(self):
if self.account_key == "managedidentity":
return BlobServiceClient(account_url=self.account_url, credential=self.credential) # type: ignore
cache_key = f"azure_blob_sas_token_{self.account_name}_{self.account_key}"
cache_result = redis_client.get(cache_key)
if cache_result is not None:
sas_token = cache_result.decode("utf-8")
else:
sas_token = generate_account_sas(
account_name=self.account_name or "",
account_key=self.account_key or "",
resource_types=ResourceTypes(service=True, container=True, object=True),
permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
expiry=naive_utc_now() + timedelta(hours=1),
)
redis_client.set(cache_key, sas_token, ex=3000)
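            # Note: the cache TTL (ex=3000 seconds, i.e. 50 minutes) is shorter than the
            # one-hour SAS expiry above, so the cached token is refreshed before it stops working.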
return BlobServiceClient(account_url=self.account_url or "", credential=sas_token)


@@ -0,0 +1,57 @@
import base64
import hashlib
from collections.abc import Generator
from baidubce.auth.bce_credentials import BceCredentials
from baidubce.bce_client_configuration import BceClientConfiguration
from baidubce.services.bos.bos_client import BosClient
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
class BaiduObsStorage(BaseStorage):
"""Implementation for Baidu OBS storage."""
def __init__(self):
super().__init__()
self.bucket_name = dify_config.BAIDU_OBS_BUCKET_NAME
client_config = BceClientConfiguration(
credentials=BceCredentials(
access_key_id=dify_config.BAIDU_OBS_ACCESS_KEY,
secret_access_key=dify_config.BAIDU_OBS_SECRET_KEY,
),
endpoint=dify_config.BAIDU_OBS_ENDPOINT,
)
self.client = BosClient(config=client_config)
def save(self, filename, data):
md5 = hashlib.md5()
md5.update(data)
content_md5 = base64.standard_b64encode(md5.digest())
self.client.put_object(
bucket_name=self.bucket_name, key=filename, data=data, content_length=len(data), content_md5=content_md5
)
def load_once(self, filename: str) -> bytes:
response = self.client.get_object(bucket_name=self.bucket_name, key=filename)
data: bytes = response.data.read()
return data
def load_stream(self, filename: str) -> Generator:
response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data
while chunk := response.read(4096):
yield chunk
def download(self, filename, target_filepath):
self.client.get_object_to_file(bucket_name=self.bucket_name, key=filename, file_name=target_filepath)
def exists(self, filename):
res = self.client.get_object_meta_data(bucket_name=self.bucket_name, key=filename)
if res is None:
return False
return True
def delete(self, filename):
self.client.delete_object(bucket_name=self.bucket_name, key=filename)


@@ -0,0 +1,40 @@
"""Abstract interface for file storage implementations."""
from abc import ABC, abstractmethod
from collections.abc import Generator
class BaseStorage(ABC):
"""Interface for file storage."""
@abstractmethod
def save(self, filename: str, data: bytes):
raise NotImplementedError
@abstractmethod
def load_once(self, filename: str) -> bytes:
raise NotImplementedError
@abstractmethod
def load_stream(self, filename: str) -> Generator:
raise NotImplementedError
@abstractmethod
def download(self, filename, target_filepath):
raise NotImplementedError
@abstractmethod
def exists(self, filename):
raise NotImplementedError
@abstractmethod
def delete(self, filename):
raise NotImplementedError
def scan(self, path, files=True, directories=False) -> list[str]:
"""
Scan files and directories in the given path.
This method is implemented only in some storage backends.
If a storage backend doesn't support scanning, it will raise NotImplementedError.
"""
raise NotImplementedError("This storage backend doesn't support scanning")


@@ -0,0 +1,5 @@
"""ClickZetta Volume storage implementation."""
from .clickzetta_volume_storage import ClickZettaVolumeStorage
__all__ = ["ClickZettaVolumeStorage"]


@@ -0,0 +1,528 @@
"""ClickZetta Volume Storage Implementation
This module provides a storage backend built on ClickZetta Volume functionality.
Supports Table Volume, User Volume, and External Volume types.
"""
import logging
import os
import tempfile
from collections.abc import Generator
from io import BytesIO
from pathlib import Path
import clickzetta
from pydantic import BaseModel, model_validator
from extensions.storage.base_storage import BaseStorage
from .volume_permissions import VolumePermissionManager, check_volume_permission
logger = logging.getLogger(__name__)
class ClickZettaVolumeConfig(BaseModel):
"""Configuration for ClickZetta Volume storage."""
username: str = ""
password: str = ""
instance: str = ""
service: str = "api.clickzetta.com"
workspace: str = "quick_start"
vcluster: str = "default_ap"
schema_name: str = "dify"
volume_type: str = "table" # table|user|external
volume_name: str | None = None # For external volumes
table_prefix: str = "dataset_" # Prefix for table volume names
dify_prefix: str = "dify_km" # Directory prefix for User Volume
permission_check: bool = True # Enable/disable permission checking
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
"""Validate the configuration values.
This method will first try to use CLICKZETTA_VOLUME_* environment variables,
then fall back to CLICKZETTA_* environment variables (for vector DB config).
"""
# Helper function to get environment variable with fallback
def get_env_with_fallback(volume_key: str, fallback_key: str, default: str | None = None) -> str:
# First try CLICKZETTA_VOLUME_* specific config
volume_value = values.get(volume_key.lower().replace("clickzetta_volume_", ""))
if volume_value:
return str(volume_value)
# Then try environment variables
volume_env = os.getenv(volume_key)
if volume_env:
return volume_env
# Fall back to existing CLICKZETTA_* config
fallback_env = os.getenv(fallback_key)
if fallback_env:
return fallback_env
return default or ""
# Apply environment variables with fallback to existing CLICKZETTA_* config
values.setdefault("username", get_env_with_fallback("CLICKZETTA_VOLUME_USERNAME", "CLICKZETTA_USERNAME"))
values.setdefault("password", get_env_with_fallback("CLICKZETTA_VOLUME_PASSWORD", "CLICKZETTA_PASSWORD"))
values.setdefault("instance", get_env_with_fallback("CLICKZETTA_VOLUME_INSTANCE", "CLICKZETTA_INSTANCE"))
values.setdefault(
"service", get_env_with_fallback("CLICKZETTA_VOLUME_SERVICE", "CLICKZETTA_SERVICE", "api.clickzetta.com")
)
values.setdefault(
"workspace", get_env_with_fallback("CLICKZETTA_VOLUME_WORKSPACE", "CLICKZETTA_WORKSPACE", "quick_start")
)
values.setdefault(
"vcluster", get_env_with_fallback("CLICKZETTA_VOLUME_VCLUSTER", "CLICKZETTA_VCLUSTER", "default_ap")
)
values.setdefault("schema_name", get_env_with_fallback("CLICKZETTA_VOLUME_SCHEMA", "CLICKZETTA_SCHEMA", "dify"))
# Volume-specific configurations (no fallback to vector DB config)
values.setdefault("volume_type", os.getenv("CLICKZETTA_VOLUME_TYPE", "table"))
values.setdefault("volume_name", os.getenv("CLICKZETTA_VOLUME_NAME"))
values.setdefault("table_prefix", os.getenv("CLICKZETTA_VOLUME_TABLE_PREFIX", "dataset_"))
values.setdefault("dify_prefix", os.getenv("CLICKZETTA_VOLUME_DIFY_PREFIX", "dify_km"))
        # Permission checking is temporarily disabled; force the flag to False
values.setdefault("permission_check", False)
# Validate required fields
if not values.get("username"):
raise ValueError("CLICKZETTA_VOLUME_USERNAME or CLICKZETTA_USERNAME is required")
if not values.get("password"):
raise ValueError("CLICKZETTA_VOLUME_PASSWORD or CLICKZETTA_PASSWORD is required")
if not values.get("instance"):
raise ValueError("CLICKZETTA_VOLUME_INSTANCE or CLICKZETTA_INSTANCE is required")
# Validate volume type
volume_type = values["volume_type"]
if volume_type not in ["table", "user", "external"]:
raise ValueError("CLICKZETTA_VOLUME_TYPE must be one of: table, user, external")
if volume_type == "external" and not values.get("volume_name"):
raise ValueError("CLICKZETTA_VOLUME_NAME is required for external volume type")
return values
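# Illustrative configuration sketch (environment variable names as consumed by validate_config
# above; values are placeholders, not real credentials):
#
#     os.environ["CLICKZETTA_VOLUME_USERNAME"] = "dify_user"
#     os.environ["CLICKZETTA_VOLUME_PASSWORD"] = "***"
#     os.environ["CLICKZETTA_VOLUME_INSTANCE"] = "my_instance"
#     os.environ["CLICKZETTA_VOLUME_TYPE"] = "table"
#     config = ClickZettaVolumeConfig()          # the "before" validator pulls in the env values
#     storage = ClickZettaVolumeStorage(config)  # ClickZettaVolumeStorage is defined below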
class ClickZettaVolumeStorage(BaseStorage):
"""ClickZetta Volume storage implementation."""
def __init__(self, config: ClickZettaVolumeConfig):
"""Initialize ClickZetta Volume storage.
Args:
config: ClickZetta Volume configuration
"""
self._config = config
self._connection = None
self._permission_manager: VolumePermissionManager | None = None
self._init_connection()
self._init_permission_manager()
logger.info("ClickZetta Volume storage initialized with type: %s", config.volume_type)
def _init_connection(self):
"""Initialize ClickZetta connection."""
try:
self._connection = clickzetta.connect(
username=self._config.username,
password=self._config.password,
instance=self._config.instance,
service=self._config.service,
workspace=self._config.workspace,
vcluster=self._config.vcluster,
schema=self._config.schema_name,
)
logger.debug("ClickZetta connection established")
except Exception:
logger.exception("Failed to connect to ClickZetta")
raise
def _init_permission_manager(self):
"""Initialize permission manager."""
try:
self._permission_manager = VolumePermissionManager(
self._connection, self._config.volume_type, self._config.volume_name
)
logger.debug("Permission manager initialized")
except Exception:
logger.exception("Failed to initialize permission manager")
raise
def _get_volume_path(self, filename: str, dataset_id: str | None = None) -> str:
"""Get the appropriate volume path based on volume type."""
if self._config.volume_type == "user":
# Add dify prefix for User Volume to organize files
return f"{self._config.dify_prefix}/{filename}"
elif self._config.volume_type == "table":
# Check if this should use User Volume (special directories)
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
# Use User Volume with dify prefix for special directories
return f"{self._config.dify_prefix}/{filename}"
if dataset_id:
return f"{self._config.table_prefix}{dataset_id}/{filename}"
else:
# Extract dataset_id from filename if not provided
# Format: dataset_id/filename
if "/" in filename:
return filename
else:
raise ValueError("dataset_id is required for table volume or filename must include dataset_id/")
elif self._config.volume_type == "external":
return filename
else:
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
def _get_volume_sql_prefix(self, dataset_id: str | None = None) -> str:
"""Get SQL prefix for volume operations."""
if self._config.volume_type == "user":
return "USER VOLUME"
elif self._config.volume_type == "table":
# For Dify's current file storage pattern, most files are stored in
# paths like "upload_files/tenant_id/uuid.ext", "tools/tenant_id/uuid.ext"
# These should use USER VOLUME for better compatibility
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
return "USER VOLUME"
# Only use TABLE VOLUME for actual dataset-specific paths
# like "dataset_12345/file.pdf" or paths with dataset_ prefix
if dataset_id:
table_name = f"{self._config.table_prefix}{dataset_id}"
else:
# Default table name for generic operations
table_name = "default_dataset"
return f"TABLE VOLUME {table_name}"
elif self._config.volume_type == "external":
return f"VOLUME {self._config.volume_name}"
else:
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
def _execute_sql(self, sql: str, fetch: bool = False):
"""Execute SQL command."""
try:
if self._connection is None:
raise RuntimeError("Connection not initialized")
with self._connection.cursor() as cursor:
cursor.execute(sql)
if fetch:
return cursor.fetchall()
return None
except Exception:
logger.exception("SQL execution failed: %s", sql)
raise
def _ensure_table_volume_exists(self, dataset_id: str):
"""Ensure table volume exists for the given dataset_id."""
if self._config.volume_type != "table" or not dataset_id:
return
# Skip for upload_files and other special directories that use USER VOLUME
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
return
table_name = f"{self._config.table_prefix}{dataset_id}"
try:
# Check if table exists
check_sql = f"SHOW TABLES LIKE '{table_name}'"
result = self._execute_sql(check_sql, fetch=True)
if not result:
# Create table with volume
create_sql = f"""
CREATE TABLE {table_name} (
id INT PRIMARY KEY AUTO_INCREMENT,
filename VARCHAR(255) NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
INDEX idx_filename (filename)
) WITH VOLUME
"""
self._execute_sql(create_sql)
logger.info("Created table volume: %s", table_name)
except Exception as e:
logger.warning("Failed to create table volume %s: %s", table_name, e)
# Don't raise exception, let the operation continue
# The table might exist but not be visible due to permissions
def save(self, filename: str, data: bytes):
"""Save data to ClickZetta Volume.
Args:
filename: File path in volume
data: File content as bytes
"""
# Extract dataset_id from filename if present
dataset_id = None
if "/" in filename and self._config.volume_type == "table":
parts = filename.split("/", 1)
if parts[0].startswith(self._config.table_prefix):
dataset_id = parts[0][len(self._config.table_prefix) :]
filename = parts[1]
else:
dataset_id = parts[0]
filename = parts[1]
# Ensure table volume exists (for table volumes)
if dataset_id:
self._ensure_table_volume_exists(dataset_id)
# Check permissions (if enabled)
if self._config.permission_check:
# Skip permission check for special directories that use USER VOLUME
if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
if self._permission_manager is not None:
check_volume_permission(self._permission_manager, "save", dataset_id)
# Write data to temporary file
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
temp_file.write(data)
temp_file_path = temp_file.name
try:
# Upload to volume
volume_prefix = self._get_volume_sql_prefix(dataset_id)
# Get the actual volume path (may include dify_km prefix)
volume_path = self._get_volume_path(filename, dataset_id)
# For User Volume, use the full path with dify_km prefix
if volume_prefix == "USER VOLUME":
sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{volume_path}'"
else:
sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{filename}'"
self._execute_sql(sql)
logger.debug("File %s saved to ClickZetta Volume at path %s", filename, volume_path)
finally:
# Clean up temporary file
Path(temp_file_path).unlink(missing_ok=True)
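    # The upload SQL generated above looks roughly like the following (illustrative values):
    #   PUT '/tmp/tmpXXXX' TO USER VOLUME FILE 'dify_km/upload_files/<tenant_id>/<uuid>.png'
    #   PUT '/tmp/tmpXXXX' TO TABLE VOLUME dataset_42 FILE 'report.pdf'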
def load_once(self, filename: str) -> bytes:
"""Load file content from ClickZetta Volume.
Args:
filename: File path in volume
Returns:
File content as bytes
"""
# Extract dataset_id from filename if present
dataset_id = None
if "/" in filename and self._config.volume_type == "table":
parts = filename.split("/", 1)
if parts[0].startswith(self._config.table_prefix):
dataset_id = parts[0][len(self._config.table_prefix) :]
filename = parts[1]
else:
dataset_id = parts[0]
filename = parts[1]
# Check permissions (if enabled)
if self._config.permission_check:
# Skip permission check for special directories that use USER VOLUME
if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
if self._permission_manager is not None:
check_volume_permission(self._permission_manager, "load_once", dataset_id)
# Download to temporary directory
with tempfile.TemporaryDirectory() as temp_dir:
volume_prefix = self._get_volume_sql_prefix(dataset_id)
# Get the actual volume path (may include dify_km prefix)
volume_path = self._get_volume_path(filename, dataset_id)
# For User Volume, use the full path with dify_km prefix
if volume_prefix == "USER VOLUME":
sql = f"GET {volume_prefix} FILE '{volume_path}' TO '{temp_dir}'"
else:
sql = f"GET {volume_prefix} FILE '{filename}' TO '{temp_dir}'"
self._execute_sql(sql)
# Find the downloaded file (may be in subdirectories)
downloaded_file = None
for root, _, files in os.walk(temp_dir):
for file in files:
if file == filename or file == os.path.basename(filename):
downloaded_file = Path(root) / file
break
if downloaded_file:
break
if not downloaded_file or not downloaded_file.exists():
raise FileNotFoundError(f"Downloaded file not found: {filename}")
content = downloaded_file.read_bytes()
logger.debug("File %s loaded from ClickZetta Volume", filename)
return content
def load_stream(self, filename: str) -> Generator:
"""Load file as stream from ClickZetta Volume.
Args:
filename: File path in volume
Yields:
File content chunks
"""
content = self.load_once(filename)
batch_size = 4096
stream = BytesIO(content)
while chunk := stream.read(batch_size):
yield chunk
logger.debug("File %s loaded as stream from ClickZetta Volume", filename)
def download(self, filename: str, target_filepath: str):
"""Download file from ClickZetta Volume to local path.
Args:
filename: File path in volume
target_filepath: Local target file path
"""
content = self.load_once(filename)
with Path(target_filepath).open("wb") as f:
f.write(content)
logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath)
def exists(self, filename: str) -> bool:
"""Check if file exists in ClickZetta Volume.
Args:
filename: File path in volume
Returns:
True if file exists, False otherwise
"""
try:
# Extract dataset_id from filename if present
dataset_id = None
if "/" in filename and self._config.volume_type == "table":
parts = filename.split("/", 1)
if parts[0].startswith(self._config.table_prefix):
dataset_id = parts[0][len(self._config.table_prefix) :]
filename = parts[1]
else:
dataset_id = parts[0]
filename = parts[1]
volume_prefix = self._get_volume_sql_prefix(dataset_id)
# Get the actual volume path (may include dify_km prefix)
volume_path = self._get_volume_path(filename, dataset_id)
# For User Volume, use the full path with dify_km prefix
if volume_prefix == "USER VOLUME":
sql = f"LIST {volume_prefix} REGEXP = '^{volume_path}$'"
else:
sql = f"LIST {volume_prefix} REGEXP = '^{filename}$'"
rows = self._execute_sql(sql, fetch=True)
exists = len(rows) > 0 if rows else False
logger.debug("File %s exists check: %s", filename, exists)
return exists
except Exception as e:
logger.warning("Error checking file existence for %s: %s", filename, e)
return False
def delete(self, filename: str):
"""Delete file from ClickZetta Volume.
Args:
filename: File path in volume
"""
if not self.exists(filename):
logger.debug("File %s not found, skip delete", filename)
return
# Extract dataset_id from filename if present
dataset_id = None
if "/" in filename and self._config.volume_type == "table":
parts = filename.split("/", 1)
if parts[0].startswith(self._config.table_prefix):
dataset_id = parts[0][len(self._config.table_prefix) :]
filename = parts[1]
else:
dataset_id = parts[0]
filename = parts[1]
volume_prefix = self._get_volume_sql_prefix(dataset_id)
# Get the actual volume path (may include dify_km prefix)
volume_path = self._get_volume_path(filename, dataset_id)
# For User Volume, use the full path with dify_km prefix
if volume_prefix == "USER VOLUME":
sql = f"REMOVE {volume_prefix} FILE '{volume_path}'"
else:
sql = f"REMOVE {volume_prefix} FILE '{filename}'"
self._execute_sql(sql)
logger.debug("File %s deleted from ClickZetta Volume", filename)
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
"""Scan files and directories in ClickZetta Volume.
Args:
path: Path to scan (dataset_id for table volumes)
files: Include files in results
directories: Include directories in results
Returns:
List of file/directory paths
"""
try:
# For table volumes, path is treated as dataset_id
dataset_id = None
if self._config.volume_type == "table":
dataset_id = path
path = "" # Root of the table volume
volume_prefix = self._get_volume_sql_prefix(dataset_id)
# For User Volume, add dify prefix to path
if volume_prefix == "USER VOLUME":
if path:
scan_path = f"{self._config.dify_prefix}/{path}"
sql = f"LIST {volume_prefix} SUBDIRECTORY '{scan_path}'"
else:
sql = f"LIST {volume_prefix} SUBDIRECTORY '{self._config.dify_prefix}'"
else:
if path:
sql = f"LIST {volume_prefix} SUBDIRECTORY '{path}'"
else:
sql = f"LIST {volume_prefix}"
rows = self._execute_sql(sql, fetch=True)
result = []
if rows:
for row in rows:
file_path = row[0] # relative_path column
# For User Volume, remove dify prefix from results
dify_prefix_with_slash = f"{self._config.dify_prefix}/"
if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash):
file_path = file_path[len(dify_prefix_with_slash) :] # Remove prefix
                    if (files and not file_path.endswith("/")) or (directories and file_path.endswith("/")):
result.append(file_path)
logger.debug("Scanned %d items in path %s", len(result), path)
return result
except Exception:
logger.exception("Error scanning path %s", path)
return []


@@ -0,0 +1,516 @@
"""ClickZetta Volume file lifecycle management
This module provides file lifecycle management features including version control,
automatic cleanup, backup and restore.
Supports complete lifecycle management for knowledge base files.
"""
import json
import logging
import operator
from dataclasses import asdict, dataclass
from datetime import datetime
from enum import StrEnum, auto
from typing import Any
logger = logging.getLogger(__name__)
class FileStatus(StrEnum):
"""File status enumeration"""
ACTIVE = auto() # Active status
ARCHIVED = auto() # Archived
DELETED = auto() # Deleted (soft delete)
BACKUP = auto() # Backup file
@dataclass
class FileMetadata:
"""File metadata"""
filename: str
size: int | None
created_at: datetime
modified_at: datetime
version: int | None
status: FileStatus
checksum: str | None = None
tags: dict[str, str] | None = None
parent_version: int | None = None
def to_dict(self):
"""Convert to dictionary format"""
data = asdict(self)
data["created_at"] = self.created_at.isoformat()
data["modified_at"] = self.modified_at.isoformat()
data["status"] = self.status.value
return data
@classmethod
def from_dict(cls, data: dict) -> "FileMetadata":
"""Create instance from dictionary"""
data = data.copy()
data["created_at"] = datetime.fromisoformat(data["created_at"])
data["modified_at"] = datetime.fromisoformat(data["modified_at"])
data["status"] = FileStatus(data["status"])
return cls(**data)
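# Round-trip sketch (illustrative values): to_dict() serializes the datetimes and status,
# and from_dict() restores an equal instance:
#
#     meta = FileMetadata(filename="a.txt", size=3, created_at=datetime.now(),
#                         modified_at=datetime.now(), version=1, status=FileStatus.ACTIVE)
#     assert FileMetadata.from_dict(meta.to_dict()) == meta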
class FileLifecycleManager:
"""File lifecycle manager"""
def __init__(self, storage, dataset_id: str | None = None):
"""Initialize lifecycle manager
Args:
storage: ClickZetta Volume storage instance
dataset_id: Dataset ID (for Table Volume)
"""
self._storage = storage
self._dataset_id = dataset_id
self._metadata_file = ".dify_file_metadata.json"
self._version_prefix = ".versions/"
self._backup_prefix = ".backups/"
self._deleted_prefix = ".deleted/"
# Get permission manager (if exists)
self._permission_manager: Any | None = getattr(storage, "_permission_manager", None)
def save_with_lifecycle(self, filename: str, data: bytes, tags: dict[str, str] | None = None) -> FileMetadata:
"""Save file and manage lifecycle
Args:
filename: File name
data: File content
tags: File tags
Returns:
File metadata
"""
# Permission check
if not self._check_permission(filename, "save"):
from .volume_permissions import VolumePermissionError
raise VolumePermissionError(
f"Permission denied for lifecycle save operation on file: {filename}",
operation="save",
                volume_type=getattr(getattr(self._storage, "_config", None), "volume_type", "unknown"),
dataset_id=self._dataset_id,
)
try:
# 1. Check if old version exists
metadata_dict = self._load_metadata()
current_metadata = metadata_dict.get(filename)
# 2. If old version exists, create version backup
if current_metadata:
self._create_version_backup(filename, current_metadata)
# 3. Calculate file information
now = datetime.now()
checksum = self._calculate_checksum(data)
new_version = (current_metadata["version"] + 1) if current_metadata else 1
# 4. Save new file
self._storage.save(filename, data)
# 5. Create metadata
created_at = now
parent_version = None
if current_metadata:
# If created_at is string, convert to datetime
if isinstance(current_metadata["created_at"], str):
created_at = datetime.fromisoformat(current_metadata["created_at"])
else:
created_at = current_metadata["created_at"]
parent_version = current_metadata["version"]
file_metadata = FileMetadata(
filename=filename,
size=len(data),
created_at=created_at,
modified_at=now,
version=new_version,
status=FileStatus.ACTIVE,
checksum=checksum,
tags=tags or {},
parent_version=parent_version,
)
# 6. Update metadata
metadata_dict[filename] = file_metadata.to_dict()
self._save_metadata(metadata_dict)
logger.info("File %s saved with lifecycle management, version %s", filename, new_version)
return file_metadata
except Exception:
logger.exception("Failed to save file with lifecycle")
raise
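    # Usage sketch (hypothetical storage instance and dataset id):
    #   manager = FileLifecycleManager(storage, dataset_id="dataset_42")
    #   meta = manager.save_with_lifecycle("report.pdf", b"...", tags={"source": "upload"})
    #   print(meta.version, meta.checksum)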
def get_file_metadata(self, filename: str) -> FileMetadata | None:
"""Get file metadata
Args:
filename: File name
Returns:
File metadata, returns None if not exists
"""
try:
metadata_dict = self._load_metadata()
if filename in metadata_dict:
return FileMetadata.from_dict(metadata_dict[filename])
return None
except Exception:
logger.exception("Failed to get file metadata for %s", filename)
return None
def list_file_versions(self, filename: str) -> list[FileMetadata]:
"""List all versions of a file
Args:
filename: File name
Returns:
File version list, sorted by version number
"""
try:
versions = []
# Get current version
current_metadata = self.get_file_metadata(filename)
if current_metadata:
versions.append(current_metadata)
# Get historical versions
try:
version_files = self._storage.scan(self._dataset_id or "", files=True)
for file_path in version_files:
if file_path.startswith(f"{self._version_prefix}{filename}.v"):
# Parse version number
version_str = file_path.split(".v")[-1].split(".")[0]
try:
_ = int(version_str)
                            # Simplified handling: ideally the metadata would be read from the
                            # version file itself; for now only the version number is parsed.
except ValueError:
continue
            except Exception:
                # If version files cannot be scanned, return only the current version
                pass
return sorted(versions, key=lambda x: x.version or 0, reverse=True)
except Exception:
logger.exception("Failed to list file versions for %s", filename)
return []
def restore_version(self, filename: str, version: int) -> bool:
"""Restore file to specified version
Args:
filename: File name
version: Version number to restore
Returns:
Whether restore succeeded
"""
try:
version_filename = f"{self._version_prefix}{filename}.v{version}"
# Check if version file exists
if not self._storage.exists(version_filename):
logger.warning("Version %s of %s not found", version, filename)
return False
# Read version file content
version_data = self._storage.load_once(version_filename)
# Save current version as backup
current_metadata = self.get_file_metadata(filename)
if current_metadata:
self._create_version_backup(filename, current_metadata.to_dict())
# Restore file
self.save_with_lifecycle(filename, version_data, {"restored_from": str(version)})
return True
except Exception:
logger.exception("Failed to restore %s to version %s", filename, version)
return False
def archive_file(self, filename: str) -> bool:
"""Archive file
Args:
filename: File name
Returns:
Whether archive succeeded
"""
# Permission check
if not self._check_permission(filename, "archive"):
logger.warning("Permission denied for archive operation on file: %s", filename)
return False
try:
# Update file status to archived
metadata_dict = self._load_metadata()
if filename not in metadata_dict:
logger.warning("File %s not found in metadata", filename)
return False
metadata_dict[filename]["status"] = FileStatus.ARCHIVED
metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
self._save_metadata(metadata_dict)
logger.info("File %s archived successfully", filename)
return True
except Exception:
logger.exception("Failed to archive file %s", filename)
return False
def soft_delete_file(self, filename: str) -> bool:
"""Soft delete file (move to deleted directory)
Args:
filename: File name
Returns:
Whether delete succeeded
"""
# Permission check
if not self._check_permission(filename, "delete"):
logger.warning("Permission denied for soft delete operation on file: %s", filename)
return False
try:
# Check if file exists
if not self._storage.exists(filename):
logger.warning("File %s not found", filename)
return False
# Read file content
file_data = self._storage.load_once(filename)
# Move to deleted directory
deleted_filename = f"{self._deleted_prefix}{filename}.{datetime.now().strftime('%Y%m%d_%H%M%S')}"
self._storage.save(deleted_filename, file_data)
# Delete original file
self._storage.delete(filename)
# Update metadata
metadata_dict = self._load_metadata()
if filename in metadata_dict:
metadata_dict[filename]["status"] = FileStatus.DELETED
metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
self._save_metadata(metadata_dict)
logger.info("File %s soft deleted successfully", filename)
return True
except Exception:
logger.exception("Failed to soft delete file %s", filename)
return False
def cleanup_old_versions(self, max_versions: int = 5, max_age_days: int = 30) -> int:
"""Cleanup old version files
Args:
max_versions: Maximum number of versions to keep
            max_age_days: Maximum retention days for version files (accepted but not currently applied)
Returns:
Number of files cleaned
"""
try:
cleaned_count = 0
# Get all version files
try:
all_files = self._storage.scan(self._dataset_id or "", files=True)
version_files = [f for f in all_files if f.startswith(self._version_prefix)]
# Group by file
file_versions: dict[str, list[tuple[int, str]]] = {}
for version_file in version_files:
# Parse filename and version
parts = version_file[len(self._version_prefix) :].split(".v")
if len(parts) >= 2:
base_filename = parts[0]
version_part = parts[1].split(".")[0]
try:
version_num = int(version_part)
if base_filename not in file_versions:
file_versions[base_filename] = []
file_versions[base_filename].append((version_num, version_file))
except ValueError:
continue
# Cleanup old versions for each file
for base_filename, versions in file_versions.items():
# Sort by version number
versions.sort(key=operator.itemgetter(0), reverse=True)
# Keep the newest max_versions versions, delete the rest
if len(versions) > max_versions:
to_delete = versions[max_versions:]
for version_num, version_file in to_delete:
self._storage.delete(version_file)
cleaned_count += 1
logger.debug("Cleaned old version: %s", version_file)
logger.info("Cleaned %d old version files", cleaned_count)
except Exception as e:
logger.warning("Could not scan for version files: %s", e)
return cleaned_count
except Exception:
logger.exception("Failed to cleanup old versions")
return 0
def get_storage_statistics(self) -> dict[str, Any]:
"""Get storage statistics
Returns:
Storage statistics dictionary
"""
try:
metadata_dict = self._load_metadata()
stats: dict[str, Any] = {
"total_files": len(metadata_dict),
"active_files": 0,
"archived_files": 0,
"deleted_files": 0,
"total_size": 0,
"versions_count": 0,
"oldest_file": None,
"newest_file": None,
}
oldest_date = None
newest_date = None
for filename, metadata in metadata_dict.items():
file_meta = FileMetadata.from_dict(metadata)
# Count file status
if file_meta.status == FileStatus.ACTIVE:
stats["active_files"] = (stats["active_files"] or 0) + 1
elif file_meta.status == FileStatus.ARCHIVED:
stats["archived_files"] = (stats["archived_files"] or 0) + 1
elif file_meta.status == FileStatus.DELETED:
stats["deleted_files"] = (stats["deleted_files"] or 0) + 1
# Count size
stats["total_size"] = (stats["total_size"] or 0) + (file_meta.size or 0)
# Count versions
stats["versions_count"] = (stats["versions_count"] or 0) + (file_meta.version or 0)
# Find newest and oldest files
if oldest_date is None or file_meta.created_at < oldest_date:
oldest_date = file_meta.created_at
stats["oldest_file"] = filename
if newest_date is None or file_meta.modified_at > newest_date:
newest_date = file_meta.modified_at
stats["newest_file"] = filename
return stats
except Exception:
logger.exception("Failed to get storage statistics")
return {}
def _create_version_backup(self, filename: str, metadata: dict):
"""Create version backup"""
try:
# Read current file content
current_data = self._storage.load_once(filename)
# Save as version file
version_filename = f"{self._version_prefix}{filename}.v{metadata['version']}"
self._storage.save(version_filename, current_data)
logger.debug("Created version backup: %s", version_filename)
except Exception as e:
logger.warning("Failed to create version backup for %s: %s", filename, e)
def _load_metadata(self) -> dict[str, Any]:
"""Load metadata file"""
try:
if self._storage.exists(self._metadata_file):
metadata_content = self._storage.load_once(self._metadata_file)
result = json.loads(metadata_content.decode("utf-8"))
return dict(result) if result else {}
else:
return {}
except Exception as e:
logger.warning("Failed to load metadata: %s", e)
return {}
def _save_metadata(self, metadata_dict: dict):
"""Save metadata file"""
try:
metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
self._storage.save(self._metadata_file, metadata_content.encode("utf-8"))
logger.debug("Metadata saved successfully")
except Exception:
logger.exception("Failed to save metadata")
raise
def _calculate_checksum(self, data: bytes) -> str:
"""Calculate file checksum"""
import hashlib
return hashlib.md5(data).hexdigest()
def _check_permission(self, filename: str, operation: str) -> bool:
"""Check file operation permission
Args:
filename: File name
operation: Operation type
Returns:
True if permission granted, False otherwise
"""
# If no permission manager, allow by default
if not self._permission_manager:
return True
try:
# Map operation type to permission
operation_mapping = {
"save": "save",
"load": "load_once",
"delete": "delete",
"archive": "delete", # Archive requires delete permission
"restore": "save", # Restore requires write permission
"cleanup": "delete", # Cleanup requires delete permission
"read": "load_once",
"write": "save",
}
mapped_operation = operation_mapping.get(operation, operation)
# Check permission
result = self._permission_manager.validate_operation(mapped_operation, self._dataset_id)
return bool(result)
except Exception:
logger.exception("Permission check failed for %s operation %s", filename, operation)
# Safe default: deny access when permission check fails
return False


@@ -0,0 +1,649 @@
"""ClickZetta Volume permission management mechanism
This module provides Volume permission checking, validation and management features.
According to ClickZetta's permission model, different Volume types have different permission requirements.
"""
import logging
from enum import StrEnum
logger = logging.getLogger(__name__)
class VolumePermission(StrEnum):
"""Volume permission type enumeration"""
READ = "SELECT" # Corresponds to ClickZetta's SELECT permission
WRITE = "INSERT,UPDATE,DELETE" # Corresponds to ClickZetta's write permissions
LIST = "SELECT" # Listing files requires SELECT permission
DELETE = "INSERT,UPDATE,DELETE" # Deleting files requires write permissions
USAGE = "USAGE" # Basic permission required for External Volume
class VolumePermissionManager:
"""Volume permission manager"""
def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: str | None = None):
"""Initialize permission manager
Args:
connection_or_config: ClickZetta connection object or configuration dictionary
volume_type: Volume type (user|table|external)
volume_name: Volume name (for external volume)
"""
# Support two initialization methods: connection object or configuration dictionary
if isinstance(connection_or_config, dict):
# Create connection from configuration dictionary
import clickzetta
config = connection_or_config
self._connection = clickzetta.connect(
username=config.get("username"),
password=config.get("password"),
instance=config.get("instance"),
service=config.get("service"),
workspace=config.get("workspace"),
vcluster=config.get("vcluster"),
schema=config.get("schema") or config.get("database"),
)
self._volume_type = config.get("volume_type", volume_type)
self._volume_name = config.get("volume_name", volume_name)
else:
# Use connection object directly
self._connection = connection_or_config
self._volume_type = volume_type
self._volume_name = volume_name
if not self._connection:
raise ValueError("Valid connection or config is required")
if not self._volume_type:
raise ValueError("volume_type is required")
self._permission_cache: dict[str, set[str]] = {}
self._current_username = None # Will get current username from connection
def check_permission(self, operation: VolumePermission, dataset_id: str | None = None) -> bool:
"""Check if user has permission to perform specific operation
Args:
operation: Type of operation to perform
dataset_id: Dataset ID (for table volume)
Returns:
True if user has permission, False otherwise
"""
try:
if self._volume_type == "user":
return self._check_user_volume_permission(operation)
elif self._volume_type == "table":
return self._check_table_volume_permission(operation, dataset_id)
elif self._volume_type == "external":
return self._check_external_volume_permission(operation)
else:
logger.warning("Unknown volume type: %s", self._volume_type)
return False
except Exception:
logger.exception("Permission check failed")
return False
def _check_user_volume_permission(self, operation: VolumePermission) -> bool:
"""Check User Volume permission
User Volume permission rules:
- User has full permissions on their own User Volume
- As long as user can connect to ClickZetta, they have basic User Volume permissions by default
- Focus more on connection authentication rather than complex permission checking
"""
try:
# Get current username
current_user = self._get_current_username()
# Check basic connection status
with self._connection.cursor() as cursor:
# Simple connection test, if query can be executed user has basic permissions
cursor.execute("SELECT 1")
result = cursor.fetchone()
if result:
logger.debug(
"User Volume permission check for %s, operation %s: granted (basic connection verified)",
current_user,
operation.name,
)
return True
else:
logger.warning(
"User Volume permission check failed: cannot verify basic connection for %s", current_user
)
return False
except Exception:
logger.exception("User Volume permission check failed")
# For User Volume, if permission check fails, it might be a configuration issue,
# provide friendlier error message
logger.info("User Volume permission check failed, but permission checking is disabled in this version")
return False
def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: str | None) -> bool:
"""Check Table Volume permission
Table Volume permission rules:
- Table Volume permissions inherit from corresponding table permissions
- SELECT permission -> can READ/LIST files
- INSERT,UPDATE,DELETE permissions -> can WRITE/DELETE files
"""
if not dataset_id:
logger.warning("dataset_id is required for table volume permission check")
return False
table_name = f"dataset_{dataset_id}" if not dataset_id.startswith("dataset_") else dataset_id
try:
# Check table permissions
permissions = self._get_table_permissions(table_name)
required_permissions = set(operation.value.split(","))
# Check if has all required permissions
has_permission = required_permissions.issubset(permissions)
logger.debug(
"Table Volume permission check for %s, operation %s: required=%s, has=%s, granted=%s",
table_name,
operation.name,
required_permissions,
permissions,
has_permission,
)
return has_permission
except Exception:
logger.exception("Table volume permission check failed for %s", table_name)
return False
def _check_external_volume_permission(self, operation: VolumePermission) -> bool:
"""Check External Volume permission
External Volume permission rules:
- Try to get permissions for External Volume
- If permission check fails, perform fallback verification
- For development environment, provide more lenient permission checking
"""
if not self._volume_name:
logger.warning("volume_name is required for external volume permission check")
return False
try:
# Check External Volume permissions
permissions = self._get_external_volume_permissions(self._volume_name)
# External Volume permission mapping: determine required permissions based on operation type
required_permissions = set()
if operation in [VolumePermission.READ, VolumePermission.LIST]:
required_permissions.add("read")
elif operation in [VolumePermission.WRITE, VolumePermission.DELETE]:
required_permissions.add("write")
# Check if has all required permissions
has_permission = required_permissions.issubset(permissions)
logger.debug(
"External Volume permission check for %s, operation %s: required=%s, has=%s, granted=%s",
self._volume_name,
operation.name,
required_permissions,
permissions,
has_permission,
)
# If permission check fails, try fallback verification
if not has_permission:
logger.info("Direct permission check failed for %s, trying fallback verification", self._volume_name)
# Fallback verification: try listing Volume to verify basic access permissions
try:
with self._connection.cursor() as cursor:
cursor.execute("SHOW VOLUMES")
volumes = cursor.fetchall()
for volume in volumes:
if len(volume) > 0 and volume[0] == self._volume_name:
logger.info("Fallback verification successful for %s", self._volume_name)
return True
except Exception as fallback_e:
logger.warning("Fallback verification failed for %s: %s", self._volume_name, fallback_e)
return has_permission
except Exception:
logger.exception("External volume permission check failed for %s", self._volume_name)
logger.info("External Volume permission check failed, but permission checking is disabled in this version")
return False
def _get_table_permissions(self, table_name: str) -> set[str]:
"""Get user permissions for specified table
Args:
table_name: Table name
Returns:
Set of user permissions for this table
"""
cache_key = f"table:{table_name}"
if cache_key in self._permission_cache:
return self._permission_cache[cache_key]
permissions = set()
try:
with self._connection.cursor() as cursor:
# Use correct ClickZetta syntax to check current user permissions
cursor.execute("SHOW GRANTS")
grants = cursor.fetchall()
# Parse permission results, find permissions for this table
for grant in grants:
if len(grant) >= 3: # Typical format: (privilege, object_type, object_name, ...)
privilege = grant[0].upper()
object_type = grant[1].upper() if len(grant) > 1 else ""
object_name = grant[2] if len(grant) > 2 else ""
# Check if it's permission for this table
                        if (
                            (object_type == "TABLE" and object_name == table_name)
                            or (object_type == "SCHEMA" and object_name in table_name)
                        ):
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
if privilege == "ALL":
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
else:
permissions.add(privilege)
# If no explicit permissions found, try executing a simple query to verify permissions
if not permissions:
try:
cursor.execute(f"SELECT COUNT(*) FROM {table_name} LIMIT 1")
permissions.add("SELECT")
except Exception:
logger.debug("Cannot query table %s, no SELECT permission", table_name)
except Exception as e:
logger.warning("Could not check table permissions for %s: %s", table_name, e)
# Safe default: deny access when permission check fails
pass
# Cache permission information
self._permission_cache[cache_key] = permissions
return permissions
def _get_current_username(self) -> str:
"""Get current username"""
if self._current_username:
return self._current_username
try:
with self._connection.cursor() as cursor:
cursor.execute("SELECT CURRENT_USER()")
result = cursor.fetchone()
if result:
self._current_username = result[0]
return str(self._current_username)
except Exception:
logger.exception("Failed to get current username")
return "unknown"
def _get_user_permissions(self, username: str) -> set[str]:
"""Get user's basic permission set"""
cache_key = f"user_permissions:{username}"
if cache_key in self._permission_cache:
return self._permission_cache[cache_key]
permissions = set()
try:
with self._connection.cursor() as cursor:
# Use correct ClickZetta syntax to check current user permissions
cursor.execute("SHOW GRANTS")
grants = cursor.fetchall()
# Parse permission results, find user's basic permissions
for grant in grants:
if len(grant) >= 3: # Typical format: (privilege, object_type, object_name, ...)
privilege = grant[0].upper()
_ = grant[1].upper() if len(grant) > 1 else ""
# Collect all relevant permissions
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
if privilege == "ALL":
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
else:
permissions.add(privilege)
except Exception as e:
logger.warning("Could not check user permissions for %s: %s", username, e)
# Safe default: deny access when permission check fails
pass
# Cache permission information
self._permission_cache[cache_key] = permissions
return permissions
def _get_external_volume_permissions(self, volume_name: str) -> set[str]:
"""Get user permissions for specified External Volume
Args:
volume_name: External Volume name
Returns:
Set of user permissions for this Volume
"""
cache_key = f"external_volume:{volume_name}"
if cache_key in self._permission_cache:
return self._permission_cache[cache_key]
permissions = set()
try:
with self._connection.cursor() as cursor:
# Use correct ClickZetta syntax to check Volume permissions
logger.info("Checking permissions for volume: %s", volume_name)
cursor.execute(f"SHOW GRANTS ON VOLUME {volume_name}")
grants = cursor.fetchall()
logger.info("Raw grants result for %s: %s", volume_name, grants)
# Parse permission results
# Format: (granted_type, privilege, conditions, granted_on, object_name, granted_to,
# grantee_name, grantor_name, grant_option, granted_time)
for grant in grants:
logger.info("Processing grant: %s", grant)
if len(grant) >= 5:
granted_type = grant[0]
privilege = grant[1].upper()
granted_on = grant[3]
object_name = grant[4]
logger.info(
"Grant details - type: %s, privilege: %s, granted_on: %s, object_name: %s",
granted_type,
privilege,
granted_on,
object_name,
)
# Check if it's permission for this Volume or hierarchical permission
if (
granted_type == "PRIVILEGE" and granted_on == "VOLUME" and object_name.endswith(volume_name)
) or (granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME"):
logger.info("Matching grant found for %s", volume_name)
if "READ" in privilege:
permissions.add("read")
logger.info("Added READ permission for %s", volume_name)
if "WRITE" in privilege:
permissions.add("write")
logger.info("Added WRITE permission for %s", volume_name)
if "ALTER" in privilege:
permissions.add("alter")
logger.info("Added ALTER permission for %s", volume_name)
if privilege == "ALL":
permissions.update(["read", "write", "alter"])
logger.info("Added ALL permissions for %s", volume_name)
logger.info("Final permissions for %s: %s", volume_name, permissions)
# If no explicit permissions found, try viewing Volume list to verify basic permissions
if not permissions:
try:
cursor.execute("SHOW VOLUMES")
volumes = cursor.fetchall()
for volume in volumes:
if len(volume) > 0 and volume[0] == volume_name:
permissions.add("read") # At least has read permission
logger.debug("Volume %s found in SHOW VOLUMES, assuming read permission", volume_name)
break
except Exception:
logger.debug("Cannot access volume %s, no basic permission", volume_name)
except Exception as e:
logger.warning("Could not check external volume permissions for %s: %s", volume_name, e)
# When permission check fails, try basic Volume access verification
try:
with self._connection.cursor() as cursor:
cursor.execute("SHOW VOLUMES")
volumes = cursor.fetchall()
for volume in volumes:
if len(volume) > 0 and volume[0] == volume_name:
logger.info("Basic volume access verified for %s", volume_name)
permissions.add("read")
permissions.add("write") # Assume has write permission
break
except Exception as basic_e:
logger.warning("Basic volume access check failed for %s: %s", volume_name, basic_e)
# Last fallback: assume basic permissions
permissions.add("read")
# Cache permission information
self._permission_cache[cache_key] = permissions
return permissions
def clear_permission_cache(self):
"""Clear permission cache"""
self._permission_cache.clear()
logger.debug("Permission cache cleared")
@property
def volume_type(self) -> str | None:
"""Get the volume type."""
return self._volume_type
def get_permission_summary(self, dataset_id: str | None = None) -> dict[str, bool]:
"""Get permission summary
Args:
dataset_id: Dataset ID (for table volume)
Returns:
Permission summary dictionary
"""
summary = {}
for operation in VolumePermission:
summary[operation.name.lower()] = self.check_permission(operation, dataset_id)
return summary
def check_inherited_permission(self, file_path: str, operation: VolumePermission) -> bool:
"""Check permission inheritance for file path
Args:
file_path: File path
operation: Operation to perform
Returns:
True if user has permission, False otherwise
"""
try:
# Parse file path
path_parts = file_path.strip("/").split("/")
if not path_parts:
logger.warning("Invalid file path for permission inheritance check")
return False
# For Table Volume, first layer is dataset_id
if self._volume_type == "table":
if len(path_parts) < 1:
return False
dataset_id = path_parts[0]
# Check permissions for dataset
has_dataset_permission = self.check_permission(operation, dataset_id)
if not has_dataset_permission:
logger.debug("Permission denied for dataset %s", dataset_id)
return False
# Check path traversal attack
if self._contains_path_traversal(file_path):
logger.warning("Path traversal attack detected: %s", file_path)
return False
# Check if accessing sensitive directory
if self._is_sensitive_path(file_path):
logger.warning("Access to sensitive path denied: %s", file_path)
return False
logger.debug("Permission inherited for path %s", file_path)
return True
elif self._volume_type == "user":
# User Volume permission inheritance
current_user = self._get_current_username()
# Check if attempting to access other user's directory
if len(path_parts) > 1 and path_parts[0] != current_user:
logger.warning("User %s attempted to access %s's directory", current_user, path_parts[0])
return False
# Check basic permissions
return self.check_permission(operation)
elif self._volume_type == "external":
# External Volume permission inheritance
# Check permissions for External Volume
return self.check_permission(operation)
else:
logger.warning("Unknown volume type for permission inheritance: %s", self._volume_type)
return False
except Exception:
logger.exception("Permission inheritance check failed")
return False
def _contains_path_traversal(self, file_path: str) -> bool:
"""Check if path contains path traversal attack"""
# Check common path traversal patterns
traversal_patterns = [
"../",
"..\\",
"..%2f",
"..%2F",
"..%5c",
"..%5C",
"%2e%2e%2f",
"%2e%2e%5c",
"....//",
"....\\\\",
]
file_path_lower = file_path.lower()
for pattern in traversal_patterns:
if pattern in file_path_lower:
return True
# Check absolute path
if file_path.startswith("/") or file_path.startswith("\\"):
return True
# Check Windows drive path
if len(file_path) >= 2 and file_path[1] == ":":
return True
return False
def _is_sensitive_path(self, file_path: str) -> bool:
"""Check if path is sensitive path"""
sensitive_patterns = [
"passwd",
"shadow",
"hosts",
"config",
"secrets",
"private",
"key",
"certificate",
"cert",
"ssl",
"database",
"backup",
"dump",
"log",
"tmp",
]
file_path_lower = file_path.lower()
return any(pattern in file_path_lower for pattern in sensitive_patterns)
def validate_operation(self, operation: str, dataset_id: str | None = None) -> bool:
"""Validate operation permission
Args:
operation: Operation name (save|load|exists|delete|scan)
dataset_id: Dataset ID
Returns:
True if operation is allowed, False otherwise
"""
operation_mapping = {
"save": VolumePermission.WRITE,
"load": VolumePermission.READ,
"load_once": VolumePermission.READ,
"load_stream": VolumePermission.READ,
"download": VolumePermission.READ,
"exists": VolumePermission.READ,
"delete": VolumePermission.DELETE,
"scan": VolumePermission.LIST,
}
if operation not in operation_mapping:
logger.warning("Unknown operation: %s", operation)
return False
volume_permission = operation_mapping[operation]
return self.check_permission(volume_permission, dataset_id)
class VolumePermissionError(Exception):
"""Volume permission error exception"""
def __init__(self, message: str, operation: str, volume_type: str, dataset_id: str | None = None):
self.operation = operation
self.volume_type = volume_type
self.dataset_id = dataset_id
super().__init__(message)
def check_volume_permission(permission_manager: VolumePermissionManager, operation: str, dataset_id: str | None = None):
"""Permission check decorator function
Args:
permission_manager: Permission manager
operation: Operation name
dataset_id: Dataset ID
Raises:
VolumePermissionError: If no permission
"""
if not permission_manager.validate_operation(operation, dataset_id):
error_message = f"Permission denied for operation '{operation}' on {permission_manager.volume_type} volume"
if dataset_id:
error_message += f" (dataset: {dataset_id})"
raise VolumePermissionError(
error_message,
operation=operation,
volume_type=permission_manager.volume_type or "unknown",
dataset_id=dataset_id,
)
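# Usage sketch (illustrative, not part of the module above): a minimal guard a
# storage backend could run before writing to a volume. Only helpers defined in
# this file are used; how the VolumePermissionManager instance is constructed is
# assumed to happen elsewhere, and the dataset id is supplied by the caller.
def _guarded_save_example(manager: VolumePermissionManager, dataset_id: str) -> bool:
    try:
        # Raises VolumePermissionError if WRITE is not granted on the volume.
        check_volume_permission(manager, "save", dataset_id)
        return True
    except VolumePermissionError as e:
        logger.warning("Denied %s on %s volume (dataset: %s)", e.operation, e.volume_type, e.dataset_id)
        return False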

View File

@@ -0,0 +1,66 @@
import base64
import io
import json
from collections.abc import Generator
from google.cloud import storage as google_cloud_storage # type: ignore
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
class GoogleCloudStorage(BaseStorage):
"""Implementation for Google Cloud storage."""
def __init__(self):
super().__init__()
self.bucket_name = dify_config.GOOGLE_STORAGE_BUCKET_NAME
service_account_json_str = dify_config.GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64
# if service_account_json_str is empty, use Application Default Credentials
if service_account_json_str:
service_account_json = base64.b64decode(service_account_json_str).decode("utf-8")
# convert str to object
service_account_obj = json.loads(service_account_json)
self.client = google_cloud_storage.Client.from_service_account_info(service_account_obj)
else:
self.client = google_cloud_storage.Client()
def save(self, filename, data):
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.blob(filename)
with io.BytesIO(data) as stream:
blob.upload_from_file(stream)
def load_once(self, filename: str) -> bytes:
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename)
if blob is None:
raise FileNotFoundError("File not found")
data: bytes = blob.download_as_bytes()
return data
def load_stream(self, filename: str) -> Generator:
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename)
if blob is None:
raise FileNotFoundError("File not found")
with blob.open(mode="rb") as blob_stream:
while chunk := blob_stream.read(4096):
yield chunk
def download(self, filename, target_filepath):
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename)
if blob is None:
raise FileNotFoundError("File not found")
blob.download_to_filename(target_filepath)
def exists(self, filename):
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.blob(filename)
return blob.exists()
def delete(self, filename):
bucket = self.client.get_bucket(self.bucket_name)
bucket.delete_blob(filename)
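# Configuration sketch (illustrative, not used by the class above):
# GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 is expected to hold the base64-encoded
# contents of a service-account JSON file; leaving it unset falls back to
# Application Default Credentials. The file path below is a placeholder.
def _encode_service_account_json(path: str = "service-account.json") -> str:
    from pathlib import Path
    return base64.b64encode(Path(path).read_bytes()).decode("ascii")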

View File

@@ -0,0 +1,51 @@
from collections.abc import Generator
from obs import ObsClient
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
class HuaweiObsStorage(BaseStorage):
"""Implementation for Huawei OBS storage."""
def __init__(self):
super().__init__()
self.bucket_name = dify_config.HUAWEI_OBS_BUCKET_NAME
self.client = ObsClient(
access_key_id=dify_config.HUAWEI_OBS_ACCESS_KEY,
secret_access_key=dify_config.HUAWEI_OBS_SECRET_KEY,
server=dify_config.HUAWEI_OBS_SERVER,
)
def save(self, filename, data):
self.client.putObject(bucketName=self.bucket_name, objectKey=filename, content=data)
def load_once(self, filename: str) -> bytes:
data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read()
return data
def load_stream(self, filename: str) -> Generator:
response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response
while chunk := response.read(4096):
yield chunk
def download(self, filename, target_filepath):
self.client.getObject(bucketName=self.bucket_name, objectKey=filename, downloadPath=target_filepath)
    def exists(self, filename):
        return self._get_meta(filename) is not None
def delete(self, filename):
self.client.deleteObject(bucketName=self.bucket_name, objectKey=filename)
def _get_meta(self, filename):
res = self.client.getObjectMetadata(bucketName=self.bucket_name, objectKey=filename)
if res and res.status and res.status < 300:
return res
else:
return None

View File

@@ -0,0 +1,101 @@
import logging
import os
from collections.abc import Generator
from pathlib import Path
import opendal
from dotenv import dotenv_values
from opendal import Operator
from extensions.storage.base_storage import BaseStorage
logger = logging.getLogger(__name__)
def _get_opendal_kwargs(*, scheme: str, env_file_path: str = ".env", prefix: str = "OPENDAL_"):
kwargs = {}
config_prefix = prefix + scheme.upper() + "_"
for key, value in os.environ.items():
if key.startswith(config_prefix):
kwargs[key[len(config_prefix) :].lower()] = value
file_env_vars: dict = dotenv_values(env_file_path) or {}
for key, value in file_env_vars.items():
if key.startswith(config_prefix) and key[len(config_prefix) :].lower() not in kwargs and value:
kwargs[key[len(config_prefix) :].lower()] = value
return kwargs
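# Example (illustrative): for scheme "s3", a process environment variable such as
# OPENDAL_S3_BUCKET=my-bucket is collected into {"bucket": "my-bucket"}; keys from
# the .env file are used only when the process environment does not already provide
# them. The variable name and value here are placeholders:
#
#   os.environ["OPENDAL_S3_BUCKET"] = "my-bucket"
#   _get_opendal_kwargs(scheme="s3")  # -> {"bucket": "my-bucket"}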
class OpenDALStorage(BaseStorage):
def __init__(self, scheme: str, **kwargs):
kwargs = kwargs or _get_opendal_kwargs(scheme=scheme)
if scheme == "fs":
root = kwargs.get("root", "storage")
Path(root).mkdir(parents=True, exist_ok=True)
retry_layer = opendal.layers.RetryLayer(max_times=3, factor=2.0, jitter=True)
self.op = Operator(scheme=scheme, **kwargs).layer(retry_layer)
logger.debug("opendal operator created with scheme %s", scheme)
logger.debug("added retry layer to opendal operator")
def save(self, filename: str, data: bytes):
self.op.write(path=filename, bs=data)
logger.debug("file %s saved", filename)
def load_once(self, filename: str) -> bytes:
if not self.exists(filename):
raise FileNotFoundError("File not found")
content: bytes = self.op.read(path=filename)
logger.debug("file %s loaded", filename)
return content
def load_stream(self, filename: str) -> Generator:
if not self.exists(filename):
raise FileNotFoundError("File not found")
batch_size = 4096
        with self.op.open(
            path=filename,
            mode="rb",
        ) as file:
while chunk := file.read(batch_size):
yield chunk
logger.debug("file %s loaded as stream", filename)
def download(self, filename: str, target_filepath: str):
if not self.exists(filename):
raise FileNotFoundError("File not found")
Path(target_filepath).write_bytes(self.op.read(path=filename))
logger.debug("file %s downloaded to %s", filename, target_filepath)
def exists(self, filename: str) -> bool:
return self.op.exists(path=filename)
def delete(self, filename: str):
if self.exists(filename):
self.op.delete(path=filename)
logger.debug("file %s deleted", filename)
return
logger.debug("file %s not found, skip delete", filename)
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
if not self.exists(path):
raise FileNotFoundError("Path not found")
all_files = self.op.list(path=path)
if files and directories:
logger.debug("files and directories on %s scanned", path)
return [f.path for f in all_files]
if files:
logger.debug("files on %s scanned", path)
return [f.path for f in all_files if not f.path.endswith("/")]
elif directories:
logger.debug("directories on %s scanned", path)
return [f.path for f in all_files if f.path.endswith("/")]
else:
raise ValueError("At least one of files or directories must be True")

View File

@@ -0,0 +1,59 @@
from collections.abc import Generator
import boto3
from botocore.exceptions import ClientError
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
class OracleOCIStorage(BaseStorage):
"""Implementation for Oracle OCI storage."""
def __init__(self):
super().__init__()
self.bucket_name = dify_config.OCI_BUCKET_NAME
self.client = boto3.client(
"s3",
aws_secret_access_key=dify_config.OCI_SECRET_KEY,
aws_access_key_id=dify_config.OCI_ACCESS_KEY,
endpoint_url=dify_config.OCI_ENDPOINT,
region_name=dify_config.OCI_REGION,
)
def save(self, filename, data):
self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
def load_once(self, filename: str) -> bytes:
try:
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
except ClientError as ex:
if ex.response.get("Error", {}).get("Code") == "NoSuchKey":
raise FileNotFoundError("File not found")
else:
raise
return data
def load_stream(self, filename: str) -> Generator:
try:
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
yield from response["Body"].iter_chunks()
except ClientError as ex:
if ex.response.get("Error", {}).get("Code") == "NoSuchKey":
raise FileNotFoundError("File not found")
else:
raise
def download(self, filename, target_filepath):
self.client.download_file(self.bucket_name, filename, target_filepath)
    def exists(self, filename):
        try:
            self.client.head_object(Bucket=self.bucket_name, Key=filename)
            return True
        except ClientError:
            return False
def delete(self, filename):
self.client.delete_object(Bucket=self.bucket_name, Key=filename)

View File

@@ -0,0 +1,17 @@
from enum import StrEnum
class StorageType(StrEnum):
ALIYUN_OSS = "aliyun-oss"
AZURE_BLOB = "azure-blob"
BAIDU_OBS = "baidu-obs"
CLICKZETTA_VOLUME = "clickzetta-volume"
GOOGLE_STORAGE = "google-storage"
HUAWEI_OBS = "huawei-obs"
LOCAL = "local"
OCI_STORAGE = "oci-storage"
OPENDAL = "opendal"
S3 = "s3"
TENCENT_COS = "tencent-cos"
VOLCENGINE_TOS = "volcengine-tos"
SUPABASE = "supabase"

View File

@@ -0,0 +1,59 @@
import io
from collections.abc import Generator
from pathlib import Path
from supabase import Client
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
class SupabaseStorage(BaseStorage):
"""Implementation for supabase obs storage."""
def __init__(self):
super().__init__()
if dify_config.SUPABASE_URL is None:
raise ValueError("SUPABASE_URL is not set")
if dify_config.SUPABASE_API_KEY is None:
raise ValueError("SUPABASE_API_KEY is not set")
if dify_config.SUPABASE_BUCKET_NAME is None:
raise ValueError("SUPABASE_BUCKET_NAME is not set")
self.bucket_name = dify_config.SUPABASE_BUCKET_NAME
self.client = Client(supabase_url=dify_config.SUPABASE_URL, supabase_key=dify_config.SUPABASE_API_KEY)
self.create_bucket(id=dify_config.SUPABASE_BUCKET_NAME, bucket_name=dify_config.SUPABASE_BUCKET_NAME)
def create_bucket(self, id, bucket_name):
if not self.bucket_exists():
self.client.storage.create_bucket(id=id, name=bucket_name)
def save(self, filename, data):
self.client.storage.from_(self.bucket_name).upload(filename, data)
def load_once(self, filename: str) -> bytes:
content: bytes = self.client.storage.from_(self.bucket_name).download(filename)
return content
def load_stream(self, filename: str) -> Generator:
result = self.client.storage.from_(self.bucket_name).download(filename)
byte_stream = io.BytesIO(result)
while chunk := byte_stream.read(4096): # Read in chunks of 4KB
yield chunk
def download(self, filename, target_filepath):
result = self.client.storage.from_(self.bucket_name).download(filename)
Path(target_filepath).write_bytes(result)
    def exists(self, filename):
        result = self.client.storage.from_(self.bucket_name).list(path=filename)
        return len(result) > 0
def delete(self, filename):
self.client.storage.from_(self.bucket_name).remove([filename])
def bucket_exists(self):
buckets = self.client.storage.list_buckets()
return any(bucket.name == self.bucket_name for bucket in buckets)

View File

@@ -0,0 +1,43 @@
from collections.abc import Generator
from qcloud_cos import CosConfig, CosS3Client
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
class TencentCosStorage(BaseStorage):
"""Implementation for Tencent Cloud COS storage."""
def __init__(self):
super().__init__()
self.bucket_name = dify_config.TENCENT_COS_BUCKET_NAME
config = CosConfig(
Region=dify_config.TENCENT_COS_REGION,
SecretId=dify_config.TENCENT_COS_SECRET_ID,
SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
Scheme=dify_config.TENCENT_COS_SCHEME,
)
self.client = CosS3Client(config)
def save(self, filename, data):
self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename)
def load_once(self, filename: str) -> bytes:
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read()
return data
def load_stream(self, filename: str) -> Generator:
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
yield from response["Body"].get_stream(chunk_size=4096)
def download(self, filename, target_filepath):
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
response["Body"].get_stream_to_file(target_filepath)
def exists(self, filename):
return self.client.object_exists(Bucket=self.bucket_name, Key=filename)
def delete(self, filename):
self.client.delete_object(Bucket=self.bucket_name, Key=filename)

View File

@@ -0,0 +1,66 @@
from collections.abc import Generator
import tos
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
class VolcengineTosStorage(BaseStorage):
"""Implementation for Volcengine TOS storage."""
def __init__(self):
super().__init__()
if not dify_config.VOLCENGINE_TOS_ACCESS_KEY:
raise ValueError("VOLCENGINE_TOS_ACCESS_KEY is not set")
if not dify_config.VOLCENGINE_TOS_SECRET_KEY:
raise ValueError("VOLCENGINE_TOS_SECRET_KEY is not set")
if not dify_config.VOLCENGINE_TOS_ENDPOINT:
raise ValueError("VOLCENGINE_TOS_ENDPOINT is not set")
if not dify_config.VOLCENGINE_TOS_REGION:
raise ValueError("VOLCENGINE_TOS_REGION is not set")
self.bucket_name = dify_config.VOLCENGINE_TOS_BUCKET_NAME
self.client = tos.TosClientV2(
ak=dify_config.VOLCENGINE_TOS_ACCESS_KEY,
sk=dify_config.VOLCENGINE_TOS_SECRET_KEY,
endpoint=dify_config.VOLCENGINE_TOS_ENDPOINT,
region=dify_config.VOLCENGINE_TOS_REGION,
)
def save(self, filename, data):
if not self.bucket_name:
raise ValueError("VOLCENGINE_TOS_BUCKET_NAME is not set")
self.client.put_object(bucket=self.bucket_name, key=filename, content=data)
def load_once(self, filename: str) -> bytes:
if not self.bucket_name:
raise FileNotFoundError("VOLCENGINE_TOS_BUCKET_NAME is not set")
data = self.client.get_object(bucket=self.bucket_name, key=filename).read()
if not isinstance(data, bytes):
raise TypeError(f"Expected bytes, got {type(data).__name__}")
return data
def load_stream(self, filename: str) -> Generator:
if not self.bucket_name:
raise FileNotFoundError("VOLCENGINE_TOS_BUCKET_NAME is not set")
response = self.client.get_object(bucket=self.bucket_name, key=filename)
while chunk := response.read(4096):
yield chunk
def download(self, filename, target_filepath):
if not self.bucket_name:
raise ValueError("VOLCENGINE_TOS_BUCKET_NAME is not set")
self.client.get_object_to_file(bucket=self.bucket_name, key=filename, file_path=target_filepath)
    def exists(self, filename):
        if not self.bucket_name:
            return False
        res = self.client.head_object(bucket=self.bucket_name, key=filename)
        return res.status_code == 200
def delete(self, filename):
if not self.bucket_name:
return
self.client.delete_object(bucket=self.bucket_name, key=filename)