dify
This commit is contained in:
0
dify/api/factories/__init__.py
Normal file
0
dify/api/factories/__init__.py
Normal file
15
dify/api/factories/agent_factory.py
Normal file
15
dify/api/factories/agent_factory.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from core.agent.strategy.plugin import PluginAgentStrategy
|
||||
from core.plugin.impl.agent import PluginAgentClient
|
||||
|
||||
|
||||
def get_plugin_agent_strategy(
|
||||
tenant_id: str, agent_strategy_provider_name: str, agent_strategy_name: str
|
||||
) -> PluginAgentStrategy:
|
||||
# TODO: use contexts to cache the agent provider
|
||||
manager = PluginAgentClient()
|
||||
agent_provider = manager.fetch_agent_strategy_provider(tenant_id, agent_strategy_provider_name)
|
||||
for agent_strategy in agent_provider.declaration.strategies:
|
||||
if agent_strategy.identity.name == agent_strategy_name:
|
||||
return PluginAgentStrategy(tenant_id, agent_strategy, agent_provider.meta.version)
|
||||
|
||||
raise ValueError(f"Agent strategy {agent_strategy_name} not found")
|
||||
565
dify/api/factories/file_factory.py
Normal file
565
dify/api/factories/file_factory.py
Normal file
@@ -0,0 +1,565 @@
|
||||
import mimetypes
|
||||
import os
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.http import parse_options_header
|
||||
|
||||
from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
|
||||
from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from models import MessageFile, ToolFile, UploadFile
|
||||
|
||||
|
||||
def build_from_message_files(
|
||||
*,
|
||||
message_files: Sequence["MessageFile"],
|
||||
tenant_id: str,
|
||||
config: FileUploadConfig | None = None,
|
||||
) -> Sequence[File]:
|
||||
results = [
|
||||
build_from_message_file(message_file=file, tenant_id=tenant_id, config=config)
|
||||
for file in message_files
|
||||
if file.belongs_to != FileBelongsTo.ASSISTANT
|
||||
]
|
||||
return results
|
||||
|
||||
|
||||
def build_from_message_file(
|
||||
*,
|
||||
message_file: "MessageFile",
|
||||
tenant_id: str,
|
||||
config: FileUploadConfig | None,
|
||||
):
|
||||
mapping = {
|
||||
"transfer_method": message_file.transfer_method,
|
||||
"url": message_file.url,
|
||||
"type": message_file.type,
|
||||
}
|
||||
|
||||
# Only include id if it exists (message_file has been committed to DB)
|
||||
if message_file.id:
|
||||
mapping["id"] = message_file.id
|
||||
|
||||
# Set the correct ID field based on transfer method
|
||||
if message_file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
mapping["tool_file_id"] = message_file.upload_file_id
|
||||
else:
|
||||
mapping["upload_file_id"] = message_file.upload_file_id
|
||||
|
||||
return build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
def build_from_mapping(
|
||||
*,
|
||||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
config: FileUploadConfig | None = None,
|
||||
strict_type_validation: bool = False,
|
||||
) -> File:
|
||||
transfer_method_value = mapping.get("transfer_method")
|
||||
if not transfer_method_value:
|
||||
raise ValueError("transfer_method is required in file mapping")
|
||||
transfer_method = FileTransferMethod.value_of(transfer_method_value)
|
||||
|
||||
build_functions: dict[FileTransferMethod, Callable] = {
|
||||
FileTransferMethod.LOCAL_FILE: _build_from_local_file,
|
||||
FileTransferMethod.REMOTE_URL: _build_from_remote_url,
|
||||
FileTransferMethod.TOOL_FILE: _build_from_tool_file,
|
||||
FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file,
|
||||
}
|
||||
|
||||
build_func = build_functions.get(transfer_method)
|
||||
if not build_func:
|
||||
raise ValueError(f"Invalid file transfer method: {transfer_method}")
|
||||
|
||||
file: File = build_func(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
transfer_method=transfer_method,
|
||||
strict_type_validation=strict_type_validation,
|
||||
)
|
||||
|
||||
if config and not _is_file_valid_with_config(
|
||||
input_file_type=mapping.get("type", FileType.CUSTOM),
|
||||
file_extension=file.extension or "",
|
||||
file_transfer_method=file.transfer_method,
|
||||
config=config,
|
||||
):
|
||||
raise ValueError(f"File validation failed for file: {file.filename}")
|
||||
|
||||
return file
|
||||
|
||||
|
||||
def build_from_mappings(
|
||||
*,
|
||||
mappings: Sequence[Mapping[str, Any]],
|
||||
config: FileUploadConfig | None = None,
|
||||
tenant_id: str,
|
||||
strict_type_validation: bool = False,
|
||||
) -> Sequence[File]:
|
||||
# TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query.
|
||||
# Implement batch processing to reduce database load when handling multiple files.
|
||||
# Filter out None/empty mappings to avoid errors
|
||||
valid_mappings = [m for m in mappings if m and m.get("transfer_method")]
|
||||
files = [
|
||||
build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
config=config,
|
||||
strict_type_validation=strict_type_validation,
|
||||
)
|
||||
for mapping in valid_mappings
|
||||
]
|
||||
|
||||
if (
|
||||
config
|
||||
# If image config is set.
|
||||
and config.image_config
|
||||
# And the number of image files exceeds the maximum limit
|
||||
and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits
|
||||
):
|
||||
raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}")
|
||||
if config and config.number_limits and len(files) > config.number_limits:
|
||||
raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}")
|
||||
|
||||
return files
|
||||
|
||||
|
||||
def _build_from_local_file(
|
||||
*,
|
||||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
transfer_method: FileTransferMethod,
|
||||
strict_type_validation: bool = False,
|
||||
) -> File:
|
||||
upload_file_id = mapping.get("upload_file_id")
|
||||
if not upload_file_id:
|
||||
raise ValueError("Invalid upload file id")
|
||||
# check if upload_file_id is a valid uuid
|
||||
try:
|
||||
uuid.UUID(upload_file_id)
|
||||
except ValueError:
|
||||
raise ValueError("Invalid upload file id format")
|
||||
stmt = select(UploadFile).where(
|
||||
UploadFile.id == upload_file_id,
|
||||
UploadFile.tenant_id == tenant_id,
|
||||
)
|
||||
|
||||
row = db.session.scalar(stmt)
|
||||
if row is None:
|
||||
raise ValueError("Invalid upload file")
|
||||
|
||||
detected_file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
|
||||
specified_type = mapping.get("type", "custom")
|
||||
|
||||
if strict_type_validation and detected_file_type.value != specified_type:
|
||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||
|
||||
if specified_type and specified_type != "custom":
|
||||
file_type = FileType(specified_type)
|
||||
else:
|
||||
file_type = detected_file_type
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
filename=row.name,
|
||||
extension="." + row.extension,
|
||||
mime_type=row.mime_type,
|
||||
tenant_id=tenant_id,
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=row.source_url,
|
||||
related_id=mapping.get("upload_file_id"),
|
||||
size=row.size,
|
||||
storage_key=row.key,
|
||||
)
|
||||
|
||||
|
||||
def _build_from_remote_url(
|
||||
*,
|
||||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
transfer_method: FileTransferMethod,
|
||||
strict_type_validation: bool = False,
|
||||
) -> File:
|
||||
upload_file_id = mapping.get("upload_file_id")
|
||||
if upload_file_id:
|
||||
try:
|
||||
uuid.UUID(upload_file_id)
|
||||
except ValueError:
|
||||
raise ValueError("Invalid upload file id format")
|
||||
stmt = select(UploadFile).where(
|
||||
UploadFile.id == upload_file_id,
|
||||
UploadFile.tenant_id == tenant_id,
|
||||
)
|
||||
|
||||
upload_file = db.session.scalar(stmt)
|
||||
if upload_file is None:
|
||||
raise ValueError("Invalid upload file")
|
||||
|
||||
detected_file_type = _standardize_file_type(
|
||||
extension="." + upload_file.extension, mime_type=upload_file.mime_type
|
||||
)
|
||||
|
||||
specified_type = mapping.get("type")
|
||||
|
||||
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
|
||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||
|
||||
if specified_type and specified_type != "custom":
|
||||
file_type = FileType(specified_type)
|
||||
else:
|
||||
file_type = detected_file_type
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
tenant_id=tenant_id,
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
|
||||
related_id=mapping.get("upload_file_id"),
|
||||
size=upload_file.size,
|
||||
storage_key=upload_file.key,
|
||||
)
|
||||
url = mapping.get("url") or mapping.get("remote_url")
|
||||
if not url:
|
||||
raise ValueError("Invalid file url")
|
||||
|
||||
mime_type, filename, file_size = _get_remote_file_info(url)
|
||||
extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin")
|
||||
|
||||
detected_file_type = _standardize_file_type(extension=extension, mime_type=mime_type)
|
||||
specified_type = mapping.get("type")
|
||||
|
||||
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
|
||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||
|
||||
if specified_type and specified_type != "custom":
|
||||
file_type = FileType(specified_type)
|
||||
else:
|
||||
file_type = detected_file_type
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
filename=filename,
|
||||
tenant_id=tenant_id,
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=url,
|
||||
mime_type=mime_type,
|
||||
extension=extension,
|
||||
size=file_size,
|
||||
storage_key="",
|
||||
)
|
||||
|
||||
|
||||
def _extract_filename(url_path: str, content_disposition: str | None) -> str | None:
|
||||
filename = None
|
||||
# Try to extract from Content-Disposition header first
|
||||
if content_disposition:
|
||||
_, params = parse_options_header(content_disposition)
|
||||
# RFC 5987 https://datatracker.ietf.org/doc/html/rfc5987: filename* takes precedence over filename
|
||||
filename = params.get("filename*") or params.get("filename")
|
||||
# Fallback to URL path if no filename from header
|
||||
if not filename:
|
||||
filename = os.path.basename(url_path)
|
||||
return filename or None
|
||||
|
||||
|
||||
def _guess_mime_type(filename: str) -> str:
|
||||
"""Guess MIME type from filename, returning empty string if None."""
|
||||
guessed_mime, _ = mimetypes.guess_type(filename)
|
||||
return guessed_mime or ""
|
||||
|
||||
|
||||
def _get_remote_file_info(url: str):
|
||||
file_size = -1
|
||||
parsed_url = urllib.parse.urlparse(url)
|
||||
url_path = parsed_url.path
|
||||
filename = os.path.basename(url_path)
|
||||
|
||||
# Initialize mime_type from filename as fallback
|
||||
mime_type = _guess_mime_type(filename)
|
||||
|
||||
resp = ssrf_proxy.head(url, follow_redirects=True)
|
||||
if resp.status_code == httpx.codes.OK:
|
||||
content_disposition = resp.headers.get("Content-Disposition")
|
||||
extracted_filename = _extract_filename(url_path, content_disposition)
|
||||
if extracted_filename:
|
||||
filename = extracted_filename
|
||||
mime_type = _guess_mime_type(filename)
|
||||
file_size = int(resp.headers.get("Content-Length", file_size))
|
||||
# Fallback to Content-Type header if mime_type is still empty
|
||||
if not mime_type:
|
||||
mime_type = resp.headers.get("Content-Type", "").split(";")[0].strip()
|
||||
|
||||
if not filename:
|
||||
extension = mimetypes.guess_extension(mime_type) or ".bin"
|
||||
filename = f"{uuid.uuid4().hex}{extension}"
|
||||
if not mime_type:
|
||||
mime_type = _guess_mime_type(filename)
|
||||
|
||||
return mime_type, filename, file_size
|
||||
|
||||
|
||||
def _build_from_tool_file(
|
||||
*,
|
||||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
transfer_method: FileTransferMethod,
|
||||
strict_type_validation: bool = False,
|
||||
) -> File:
|
||||
tool_file = db.session.scalar(
|
||||
select(ToolFile).where(
|
||||
ToolFile.id == mapping.get("tool_file_id"),
|
||||
ToolFile.tenant_id == tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
if tool_file is None:
|
||||
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
|
||||
|
||||
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
|
||||
|
||||
detected_file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
|
||||
|
||||
specified_type = mapping.get("type")
|
||||
|
||||
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
|
||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||
|
||||
if specified_type and specified_type != "custom":
|
||||
file_type = FileType(specified_type)
|
||||
else:
|
||||
file_type = detected_file_type
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
tenant_id=tenant_id,
|
||||
filename=tool_file.name,
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=tool_file.original_url,
|
||||
related_id=tool_file.id,
|
||||
extension=extension,
|
||||
mime_type=tool_file.mimetype,
|
||||
size=tool_file.size,
|
||||
storage_key=tool_file.file_key,
|
||||
)
|
||||
|
||||
|
||||
def _build_from_datasource_file(
|
||||
*,
|
||||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
transfer_method: FileTransferMethod,
|
||||
strict_type_validation: bool = False,
|
||||
) -> File:
|
||||
datasource_file = (
|
||||
db.session.query(UploadFile)
|
||||
.where(
|
||||
UploadFile.id == mapping.get("datasource_file_id"),
|
||||
UploadFile.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if datasource_file is None:
|
||||
raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
|
||||
|
||||
extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
|
||||
|
||||
detected_file_type = _standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
|
||||
|
||||
specified_type = mapping.get("type")
|
||||
|
||||
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
|
||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||
|
||||
if specified_type and specified_type != "custom":
|
||||
file_type = FileType(specified_type)
|
||||
else:
|
||||
file_type = detected_file_type
|
||||
|
||||
return File(
|
||||
id=mapping.get("datasource_file_id"),
|
||||
tenant_id=tenant_id,
|
||||
filename=datasource_file.name,
|
||||
type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
remote_url=datasource_file.source_url,
|
||||
related_id=datasource_file.id,
|
||||
extension=extension,
|
||||
mime_type=datasource_file.mime_type,
|
||||
size=datasource_file.size,
|
||||
storage_key=datasource_file.key,
|
||||
url=datasource_file.source_url,
|
||||
)
|
||||
|
||||
|
||||
def _is_file_valid_with_config(
|
||||
*,
|
||||
input_file_type: str,
|
||||
file_extension: str,
|
||||
file_transfer_method: FileTransferMethod,
|
||||
config: FileUploadConfig,
|
||||
) -> bool:
|
||||
# FIXME(QIN2DIM): Always allow tool files (files generated by the assistant/model)
|
||||
# These are internally generated and should bypass user upload restrictions
|
||||
if file_transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
return True
|
||||
|
||||
if (
|
||||
config.allowed_file_types
|
||||
and input_file_type not in config.allowed_file_types
|
||||
and input_file_type != FileType.CUSTOM
|
||||
):
|
||||
return False
|
||||
|
||||
if (
|
||||
input_file_type == FileType.CUSTOM
|
||||
and config.allowed_file_extensions is not None
|
||||
and file_extension not in config.allowed_file_extensions
|
||||
):
|
||||
return False
|
||||
|
||||
if input_file_type == FileType.IMAGE:
|
||||
if (
|
||||
config.image_config
|
||||
and config.image_config.transfer_methods
|
||||
and file_transfer_method not in config.image_config.transfer_methods
|
||||
):
|
||||
return False
|
||||
elif config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType:
|
||||
"""
|
||||
Infer the possible actual type of the file based on the extension and mime_type
|
||||
"""
|
||||
guessed_type = None
|
||||
if extension:
|
||||
guessed_type = _get_file_type_by_extension(extension)
|
||||
if guessed_type is None and mime_type:
|
||||
guessed_type = _get_file_type_by_mimetype(mime_type)
|
||||
return guessed_type or FileType.CUSTOM
|
||||
|
||||
|
||||
def _get_file_type_by_extension(extension: str) -> FileType | None:
|
||||
extension = extension.lstrip(".")
|
||||
if extension in IMAGE_EXTENSIONS:
|
||||
return FileType.IMAGE
|
||||
elif extension in VIDEO_EXTENSIONS:
|
||||
return FileType.VIDEO
|
||||
elif extension in AUDIO_EXTENSIONS:
|
||||
return FileType.AUDIO
|
||||
elif extension in DOCUMENT_EXTENSIONS:
|
||||
return FileType.DOCUMENT
|
||||
return None
|
||||
|
||||
|
||||
def _get_file_type_by_mimetype(mime_type: str) -> FileType | None:
|
||||
if "image" in mime_type:
|
||||
file_type = FileType.IMAGE
|
||||
elif "video" in mime_type:
|
||||
file_type = FileType.VIDEO
|
||||
elif "audio" in mime_type:
|
||||
file_type = FileType.AUDIO
|
||||
elif "text" in mime_type or "pdf" in mime_type:
|
||||
file_type = FileType.DOCUMENT
|
||||
else:
|
||||
file_type = FileType.CUSTOM
|
||||
return file_type
|
||||
|
||||
|
||||
def get_file_type_by_mime_type(mime_type: str) -> FileType:
|
||||
return _get_file_type_by_mimetype(mime_type) or FileType.CUSTOM
|
||||
|
||||
|
||||
class StorageKeyLoader:
|
||||
"""FileKeyLoader load the storage key from database for a list of files.
|
||||
This loader is batched, the database query count is constant regardless of the input size.
|
||||
"""
|
||||
|
||||
def __init__(self, session: Session, tenant_id: str):
|
||||
self._session = session
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]:
|
||||
stmt = select(UploadFile).where(
|
||||
UploadFile.id.in_(upload_file_ids),
|
||||
UploadFile.tenant_id == self._tenant_id,
|
||||
)
|
||||
|
||||
return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)}
|
||||
|
||||
def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]:
|
||||
stmt = select(ToolFile).where(
|
||||
ToolFile.id.in_(tool_file_ids),
|
||||
ToolFile.tenant_id == self._tenant_id,
|
||||
)
|
||||
return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)}
|
||||
|
||||
def load_storage_keys(self, files: Sequence[File]):
|
||||
"""Loads storage keys for a sequence of files by retrieving the corresponding
|
||||
`UploadFile` or `ToolFile` records from the database based on their transfer method.
|
||||
|
||||
This method doesn't modify the input sequence structure but updates the `_storage_key`
|
||||
property of each file object by extracting the relevant key from its database record.
|
||||
|
||||
Performance note: This is a batched operation where database query count remains constant
|
||||
regardless of input size. However, for optimal performance, input sequences should contain
|
||||
fewer than 1000 files. For larger collections, split into smaller batches and process each
|
||||
batch separately.
|
||||
"""
|
||||
|
||||
upload_file_ids: list[uuid.UUID] = []
|
||||
tool_file_ids: list[uuid.UUID] = []
|
||||
for file in files:
|
||||
related_model_id = file.related_id
|
||||
if file.related_id is None:
|
||||
raise ValueError("file id should not be None.")
|
||||
if file.tenant_id != self._tenant_id:
|
||||
err_msg = (
|
||||
f"invalid file, expected tenant_id={self._tenant_id}, "
|
||||
f"got tenant_id={file.tenant_id}, file_id={file.id}, related_model_id={related_model_id}"
|
||||
)
|
||||
raise ValueError(err_msg)
|
||||
model_id = uuid.UUID(related_model_id)
|
||||
|
||||
if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
|
||||
upload_file_ids.append(model_id)
|
||||
elif file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
tool_file_ids.append(model_id)
|
||||
|
||||
tool_files = self._load_tool_files(tool_file_ids)
|
||||
upload_files = self._load_upload_files(upload_file_ids)
|
||||
for file in files:
|
||||
model_id = uuid.UUID(file.related_id)
|
||||
if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
|
||||
upload_file_row = upload_files.get(model_id)
|
||||
if upload_file_row is None:
|
||||
raise ValueError(f"Upload file not found for id: {model_id}")
|
||||
file.storage_key = upload_file_row.key
|
||||
elif file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
tool_file_row = tool_files.get(model_id)
|
||||
if tool_file_row is None:
|
||||
raise ValueError(f"Tool file not found for id: {model_id}")
|
||||
file.storage_key = tool_file_row.file_key
|
||||
308
dify/api/factories/variable_factory.py
Normal file
308
dify/api/factories/variable_factory.py
Normal file
@@ -0,0 +1,308 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File
|
||||
from core.variables.exc import VariableError
|
||||
from core.variables.segments import (
|
||||
ArrayAnySegment,
|
||||
ArrayBooleanSegment,
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArraySegment,
|
||||
ArrayStringSegment,
|
||||
BooleanSegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
Segment,
|
||||
StringSegment,
|
||||
)
|
||||
from core.variables.types import SegmentType
|
||||
from core.variables.variables import (
|
||||
ArrayAnyVariable,
|
||||
ArrayBooleanVariable,
|
||||
ArrayFileVariable,
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
ArrayStringVariable,
|
||||
BooleanVariable,
|
||||
FileVariable,
|
||||
FloatVariable,
|
||||
IntegerVariable,
|
||||
NoneVariable,
|
||||
ObjectVariable,
|
||||
SecretVariable,
|
||||
StringVariable,
|
||||
Variable,
|
||||
)
|
||||
from core.workflow.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
ENVIRONMENT_VARIABLE_NODE_ID,
|
||||
)
|
||||
|
||||
|
||||
class UnsupportedSegmentTypeError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class TypeMismatchError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# Define the constant
|
||||
SEGMENT_TO_VARIABLE_MAP = {
|
||||
ArrayAnySegment: ArrayAnyVariable,
|
||||
ArrayBooleanSegment: ArrayBooleanVariable,
|
||||
ArrayFileSegment: ArrayFileVariable,
|
||||
ArrayNumberSegment: ArrayNumberVariable,
|
||||
ArrayObjectSegment: ArrayObjectVariable,
|
||||
ArrayStringSegment: ArrayStringVariable,
|
||||
BooleanSegment: BooleanVariable,
|
||||
FileSegment: FileVariable,
|
||||
FloatSegment: FloatVariable,
|
||||
IntegerSegment: IntegerVariable,
|
||||
NoneSegment: NoneVariable,
|
||||
ObjectSegment: ObjectVariable,
|
||||
StringSegment: StringVariable,
|
||||
}
|
||||
|
||||
|
||||
def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
||||
if not mapping.get("name"):
|
||||
raise VariableError("missing name")
|
||||
return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]])
|
||||
|
||||
|
||||
def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
||||
if not mapping.get("name"):
|
||||
raise VariableError("missing name")
|
||||
return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
|
||||
|
||||
|
||||
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
||||
if not mapping.get("variable"):
|
||||
raise VariableError("missing variable")
|
||||
return mapping["variable"]
|
||||
|
||||
|
||||
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
|
||||
"""
|
||||
This factory function is used to create the environment variable or the conversation variable,
|
||||
not support the File type.
|
||||
"""
|
||||
if (value_type := mapping.get("value_type")) is None:
|
||||
raise VariableError("missing value type")
|
||||
if (value := mapping.get("value")) is None:
|
||||
raise VariableError("missing value")
|
||||
|
||||
result: Variable
|
||||
match value_type:
|
||||
case SegmentType.STRING:
|
||||
result = StringVariable.model_validate(mapping)
|
||||
case SegmentType.SECRET:
|
||||
result = SecretVariable.model_validate(mapping)
|
||||
case SegmentType.NUMBER | SegmentType.INTEGER if isinstance(value, int):
|
||||
mapping = dict(mapping)
|
||||
mapping["value_type"] = SegmentType.INTEGER
|
||||
result = IntegerVariable.model_validate(mapping)
|
||||
case SegmentType.NUMBER | SegmentType.FLOAT if isinstance(value, float):
|
||||
mapping = dict(mapping)
|
||||
mapping["value_type"] = SegmentType.FLOAT
|
||||
result = FloatVariable.model_validate(mapping)
|
||||
case SegmentType.BOOLEAN:
|
||||
result = BooleanVariable.model_validate(mapping)
|
||||
case SegmentType.NUMBER if not isinstance(value, float | int):
|
||||
raise VariableError(f"invalid number value {value}")
|
||||
case SegmentType.OBJECT if isinstance(value, dict):
|
||||
result = ObjectVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_STRING if isinstance(value, list):
|
||||
result = ArrayStringVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_NUMBER if isinstance(value, list):
|
||||
result = ArrayNumberVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
|
||||
result = ArrayObjectVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_BOOLEAN if isinstance(value, list):
|
||||
result = ArrayBooleanVariable.model_validate(mapping)
|
||||
case _:
|
||||
raise VariableError(f"not supported value type {value_type}")
|
||||
if result.size > dify_config.MAX_VARIABLE_SIZE:
|
||||
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
|
||||
if not result.selector:
|
||||
result = result.model_copy(update={"selector": selector})
|
||||
return cast(Variable, result)
|
||||
|
||||
|
||||
def build_segment(value: Any, /) -> Segment:
|
||||
# NOTE: If you have runtime type information available, consider using the `build_segment_with_type`
|
||||
# below
|
||||
if value is None:
|
||||
return NoneSegment()
|
||||
if isinstance(value, Segment):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return StringSegment(value=value)
|
||||
if isinstance(value, bool):
|
||||
return BooleanSegment(value=value)
|
||||
if isinstance(value, int):
|
||||
return IntegerSegment(value=value)
|
||||
if isinstance(value, float):
|
||||
return FloatSegment(value=value)
|
||||
if isinstance(value, dict):
|
||||
return ObjectSegment(value=value)
|
||||
if isinstance(value, File):
|
||||
return FileSegment(value=value)
|
||||
if isinstance(value, list):
|
||||
items = [build_segment(item) for item in value]
|
||||
types = {item.value_type for item in items}
|
||||
if all(isinstance(item, ArraySegment) for item in items):
|
||||
return ArrayAnySegment(value=value)
|
||||
elif len(types) != 1:
|
||||
if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}):
|
||||
return ArrayNumberSegment(value=value)
|
||||
return ArrayAnySegment(value=value)
|
||||
|
||||
match types.pop():
|
||||
case SegmentType.STRING:
|
||||
return ArrayStringSegment(value=value)
|
||||
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
|
||||
return ArrayNumberSegment(value=value)
|
||||
case SegmentType.BOOLEAN:
|
||||
return ArrayBooleanSegment(value=value)
|
||||
case SegmentType.OBJECT:
|
||||
return ArrayObjectSegment(value=value)
|
||||
case SegmentType.FILE:
|
||||
return ArrayFileSegment(value=value)
|
||||
case SegmentType.NONE:
|
||||
return ArrayAnySegment(value=value)
|
||||
case _:
|
||||
# This should be unreachable.
|
||||
raise ValueError(f"not supported value {value}")
|
||||
raise ValueError(f"not supported value {value}")
|
||||
|
||||
|
||||
_segment_factory: Mapping[SegmentType, type[Segment]] = {
|
||||
SegmentType.NONE: NoneSegment,
|
||||
SegmentType.STRING: StringSegment,
|
||||
SegmentType.INTEGER: IntegerSegment,
|
||||
SegmentType.FLOAT: FloatSegment,
|
||||
SegmentType.FILE: FileSegment,
|
||||
SegmentType.BOOLEAN: BooleanSegment,
|
||||
SegmentType.OBJECT: ObjectSegment,
|
||||
# Array types
|
||||
SegmentType.ARRAY_ANY: ArrayAnySegment,
|
||||
SegmentType.ARRAY_STRING: ArrayStringSegment,
|
||||
SegmentType.ARRAY_NUMBER: ArrayNumberSegment,
|
||||
SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
|
||||
SegmentType.ARRAY_FILE: ArrayFileSegment,
|
||||
SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment,
|
||||
}
|
||||
|
||||
|
||||
def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
|
||||
"""
|
||||
Build a segment with explicit type checking.
|
||||
|
||||
This function creates a segment from a value while enforcing type compatibility
|
||||
with the specified segment_type. It provides stricter type validation compared
|
||||
to the standard build_segment function.
|
||||
|
||||
Args:
|
||||
segment_type: The expected SegmentType for the resulting segment
|
||||
value: The value to be converted into a segment
|
||||
|
||||
Returns:
|
||||
Segment: A segment instance of the appropriate type
|
||||
|
||||
Raises:
|
||||
TypeMismatchError: If the value type doesn't match the expected segment_type
|
||||
|
||||
Special Cases:
|
||||
- For empty list [] values, if segment_type is array[*], returns the corresponding array type
|
||||
- Type validation is performed before segment creation
|
||||
|
||||
Examples:
|
||||
>>> build_segment_with_type(SegmentType.STRING, "hello")
|
||||
StringSegment(value="hello")
|
||||
|
||||
>>> build_segment_with_type(SegmentType.ARRAY_STRING, [])
|
||||
ArrayStringSegment(value=[])
|
||||
|
||||
>>> build_segment_with_type(SegmentType.STRING, 123)
|
||||
# Raises TypeMismatchError
|
||||
"""
|
||||
# Handle None values
|
||||
if value is None:
|
||||
if segment_type == SegmentType.NONE:
|
||||
return NoneSegment()
|
||||
else:
|
||||
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None")
|
||||
|
||||
# Handle empty list special case for array types
|
||||
if isinstance(value, list) and len(value) == 0:
|
||||
if segment_type == SegmentType.ARRAY_ANY:
|
||||
return ArrayAnySegment(value=value)
|
||||
elif segment_type == SegmentType.ARRAY_STRING:
|
||||
return ArrayStringSegment(value=value)
|
||||
elif segment_type == SegmentType.ARRAY_BOOLEAN:
|
||||
return ArrayBooleanSegment(value=value)
|
||||
elif segment_type == SegmentType.ARRAY_NUMBER:
|
||||
return ArrayNumberSegment(value=value)
|
||||
elif segment_type == SegmentType.ARRAY_OBJECT:
|
||||
return ArrayObjectSegment(value=value)
|
||||
elif segment_type == SegmentType.ARRAY_FILE:
|
||||
return ArrayFileSegment(value=value)
|
||||
else:
|
||||
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list")
|
||||
|
||||
inferred_type = SegmentType.infer_segment_type(value)
|
||||
# Type compatibility checking
|
||||
if inferred_type is None:
|
||||
raise TypeMismatchError(
|
||||
f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}"
|
||||
)
|
||||
if inferred_type == segment_type:
|
||||
segment_class = _segment_factory[segment_type]
|
||||
return segment_class(value_type=segment_type, value=value)
|
||||
elif segment_type == SegmentType.NUMBER and inferred_type in (
|
||||
SegmentType.INTEGER,
|
||||
SegmentType.FLOAT,
|
||||
):
|
||||
segment_class = _segment_factory[inferred_type]
|
||||
return segment_class(value_type=inferred_type, value=value)
|
||||
else:
|
||||
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}")
|
||||
|
||||
|
||||
def segment_to_variable(
|
||||
*,
|
||||
segment: Segment,
|
||||
selector: Sequence[str],
|
||||
id: str | None = None,
|
||||
name: str | None = None,
|
||||
description: str = "",
|
||||
) -> Variable:
|
||||
if isinstance(segment, Variable):
|
||||
return segment
|
||||
name = name or selector[-1]
|
||||
id = id or str(uuid4())
|
||||
|
||||
segment_type = type(segment)
|
||||
if segment_type not in SEGMENT_TO_VARIABLE_MAP:
|
||||
raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")
|
||||
|
||||
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
|
||||
return cast(
|
||||
Variable,
|
||||
variable_class(
|
||||
id=id,
|
||||
name=name,
|
||||
description=description,
|
||||
value=segment.value,
|
||||
selector=list(selector),
|
||||
),
|
||||
)
|
||||
Reference in New Issue
Block a user