dify
This commit is contained in:
0
dify/api/core/app/app_config/__init__.py
Normal file
0
dify/api/core/app/app_config/__init__.py
Normal file
49
dify/api/core/app/app_config/base_app_config_manager.py
Normal file
49
dify/api/core/app/app_config/base_app_config_manager.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import AppAdditionalFeatures
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
|
||||
from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
|
||||
from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager
|
||||
from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager
|
||||
from core.app.app_config.features.suggested_questions_after_answer.manager import (
|
||||
SuggestedQuestionsAfterAnswerConfigManager,
|
||||
)
|
||||
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class BaseAppConfigManager:
|
||||
@classmethod
|
||||
def convert_features(cls, config_dict: Mapping[str, Any], app_mode: AppMode) -> AppAdditionalFeatures:
|
||||
"""
|
||||
Convert app config to app model config
|
||||
|
||||
:param config_dict: app config
|
||||
:param app_mode: app mode
|
||||
"""
|
||||
config_dict = dict(config_dict.items())
|
||||
|
||||
additional_features = AppAdditionalFeatures()
|
||||
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)
|
||||
|
||||
additional_features.file_upload = FileUploadConfigManager.convert(
|
||||
config=config_dict, is_vision=app_mode in {AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT}
|
||||
)
|
||||
|
||||
additional_features.opening_statement, additional_features.suggested_questions = (
|
||||
OpeningStatementConfigManager.convert(config=config_dict)
|
||||
)
|
||||
|
||||
additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert(
|
||||
config=config_dict
|
||||
)
|
||||
|
||||
additional_features.more_like_this = MoreLikeThisConfigManager.convert(config=config_dict)
|
||||
|
||||
additional_features.speech_to_text = SpeechToTextConfigManager.convert(config=config_dict)
|
||||
|
||||
additional_features.text_to_speech = TextToSpeechConfigManager.convert(config=config_dict)
|
||||
|
||||
return additional_features
|
||||
0
dify/api/core/app/app_config/common/__init__.py
Normal file
0
dify/api/core/app/app_config/common/__init__.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||
|
||||
|
||||
def get_parameters_from_feature_dict(
|
||||
*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Mapping from feature dict to webapp parameters
|
||||
"""
|
||||
return {
|
||||
"opening_statement": features_dict.get("opening_statement"),
|
||||
"suggested_questions": features_dict.get("suggested_questions", []),
|
||||
"suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
|
||||
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
|
||||
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
|
||||
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
|
||||
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
|
||||
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
|
||||
"user_input_form": user_input_form,
|
||||
"sensitive_word_avoidance": features_dict.get(
|
||||
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
|
||||
),
|
||||
"file_upload": features_dict.get(
|
||||
"file_upload",
|
||||
{
|
||||
"image": {
|
||||
"enabled": False,
|
||||
"number_limits": DEFAULT_FILE_NUMBER_LIMITS,
|
||||
"detail": "high",
|
||||
"transfer_methods": ["remote_url", "local_file"],
|
||||
}
|
||||
},
|
||||
),
|
||||
"system_parameters": {
|
||||
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
||||
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
from core.app.app_config.entities import SensitiveWordAvoidanceEntity
|
||||
from core.moderation.factory import ModerationFactory
|
||||
|
||||
|
||||
class SensitiveWordAvoidanceConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> SensitiveWordAvoidanceEntity | None:
|
||||
sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
|
||||
if not sensitive_word_avoidance_dict:
|
||||
return None
|
||||
|
||||
if sensitive_word_avoidance_dict.get("enabled"):
|
||||
return SensitiveWordAvoidanceEntity(
|
||||
type=sensitive_word_avoidance_dict.get("type"),
|
||||
config=sensitive_word_avoidance_dict.get("config"),
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(
|
||||
cls, tenant_id: str, config: dict, only_structure_validate: bool = False
|
||||
) -> tuple[dict, list[str]]:
|
||||
if not config.get("sensitive_word_avoidance"):
|
||||
config["sensitive_word_avoidance"] = {"enabled": False}
|
||||
|
||||
if not isinstance(config["sensitive_word_avoidance"], dict):
|
||||
raise ValueError("sensitive_word_avoidance must be of dict type")
|
||||
|
||||
if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]:
|
||||
config["sensitive_word_avoidance"]["enabled"] = False
|
||||
|
||||
if config["sensitive_word_avoidance"]["enabled"]:
|
||||
if not config["sensitive_word_avoidance"].get("type"):
|
||||
raise ValueError("sensitive_word_avoidance.type is required")
|
||||
|
||||
if not only_structure_validate:
|
||||
typ = config["sensitive_word_avoidance"]["type"]
|
||||
if not isinstance(typ, str):
|
||||
raise ValueError("sensitive_word_avoidance.type must be a string")
|
||||
|
||||
sensitive_word_avoidance_config = config["sensitive_word_avoidance"].get("config")
|
||||
if sensitive_word_avoidance_config is None:
|
||||
sensitive_word_avoidance_config = {}
|
||||
if not isinstance(sensitive_word_avoidance_config, dict):
|
||||
raise ValueError("sensitive_word_avoidance.config must be a dict")
|
||||
|
||||
ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)
|
||||
|
||||
return config, ["sensitive_word_avoidance"]
|
||||
@@ -0,0 +1,80 @@
|
||||
from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity
|
||||
from core.agent.prompt.template import REACT_PROMPT_TEMPLATES
|
||||
|
||||
|
||||
class AgentConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> AgentEntity | None:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
:param config: model config args
|
||||
"""
|
||||
if "agent_mode" in config and config["agent_mode"] and "enabled" in config["agent_mode"]:
|
||||
agent_dict = config.get("agent_mode", {})
|
||||
agent_strategy = agent_dict.get("strategy", "cot")
|
||||
|
||||
if agent_strategy == "function_call":
|
||||
strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
elif agent_strategy in {"cot", "react"}:
|
||||
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
else:
|
||||
# old configs, try to detect default strategy
|
||||
if config["model"]["provider"] == "openai":
|
||||
strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
else:
|
||||
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
|
||||
agent_tools = []
|
||||
for tool in agent_dict.get("tools", []):
|
||||
keys = tool.keys()
|
||||
if len(keys) >= 4:
|
||||
if "enabled" not in tool or not tool["enabled"]:
|
||||
continue
|
||||
|
||||
agent_tool_properties = {
|
||||
"provider_type": tool["provider_type"],
|
||||
"provider_id": tool["provider_id"],
|
||||
"tool_name": tool["tool_name"],
|
||||
"tool_parameters": tool.get("tool_parameters", {}),
|
||||
"credential_id": tool.get("credential_id", None),
|
||||
}
|
||||
|
||||
agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties))
|
||||
|
||||
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
|
||||
"react_router",
|
||||
"router",
|
||||
}:
|
||||
agent_prompt = agent_dict.get("prompt", None) or {}
|
||||
# check model mode
|
||||
model_mode = config.get("model", {}).get("mode", "completion")
|
||||
if model_mode == "completion":
|
||||
agent_prompt_entity = AgentPromptEntity(
|
||||
first_prompt=agent_prompt.get(
|
||||
"first_prompt", REACT_PROMPT_TEMPLATES["english"]["completion"]["prompt"]
|
||||
),
|
||||
next_iteration=agent_prompt.get(
|
||||
"next_iteration", REACT_PROMPT_TEMPLATES["english"]["completion"]["agent_scratchpad"]
|
||||
),
|
||||
)
|
||||
else:
|
||||
agent_prompt_entity = AgentPromptEntity(
|
||||
first_prompt=agent_prompt.get(
|
||||
"first_prompt", REACT_PROMPT_TEMPLATES["english"]["chat"]["prompt"]
|
||||
),
|
||||
next_iteration=agent_prompt.get(
|
||||
"next_iteration", REACT_PROMPT_TEMPLATES["english"]["chat"]["agent_scratchpad"]
|
||||
),
|
||||
)
|
||||
|
||||
return AgentEntity(
|
||||
provider=config["model"]["provider"],
|
||||
model=config["model"]["name"],
|
||||
strategy=strategy,
|
||||
prompt=agent_prompt_entity,
|
||||
tools=agent_tools,
|
||||
max_iteration=agent_dict.get("max_iteration", 10),
|
||||
)
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,254 @@
|
||||
import uuid
|
||||
from typing import Literal, cast
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
DatasetRetrieveConfigEntity,
|
||||
MetadataFilteringCondition,
|
||||
ModelConfig,
|
||||
)
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from models.model import AppMode
|
||||
from services.dataset_service import DatasetService
|
||||
|
||||
|
||||
class DatasetConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> DatasetEntity | None:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
:param config: model config args
|
||||
"""
|
||||
dataset_ids = []
|
||||
if "datasets" in config.get("dataset_configs", {}):
|
||||
datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})
|
||||
|
||||
for dataset in datasets.get("datasets", []):
|
||||
keys = list(dataset.keys())
|
||||
if len(keys) == 0 or keys[0] != "dataset":
|
||||
continue
|
||||
|
||||
dataset = dataset["dataset"]
|
||||
|
||||
if "enabled" not in dataset or not dataset["enabled"]:
|
||||
continue
|
||||
|
||||
dataset_id = dataset.get("id", None)
|
||||
if dataset_id:
|
||||
dataset_ids.append(dataset_id)
|
||||
|
||||
if (
|
||||
"agent_mode" in config
|
||||
and config["agent_mode"]
|
||||
and "enabled" in config["agent_mode"]
|
||||
and config["agent_mode"]["enabled"]
|
||||
):
|
||||
agent_dict = config.get("agent_mode", {})
|
||||
|
||||
for tool in agent_dict.get("tools", []):
|
||||
keys = tool.keys()
|
||||
if len(keys) == 1:
|
||||
# old standard
|
||||
key = list(tool.keys())[0]
|
||||
|
||||
if key != "dataset":
|
||||
continue
|
||||
|
||||
tool_item = tool[key]
|
||||
|
||||
if "enabled" not in tool_item or not tool_item["enabled"]:
|
||||
continue
|
||||
|
||||
dataset_id = tool_item["id"]
|
||||
dataset_ids.append(dataset_id)
|
||||
|
||||
if len(dataset_ids) == 0:
|
||||
return None
|
||||
|
||||
# dataset configs
|
||||
if "dataset_configs" in config and config.get("dataset_configs"):
|
||||
dataset_configs = config.get("dataset_configs")
|
||||
else:
|
||||
dataset_configs = {"retrieval_model": "multiple"}
|
||||
if dataset_configs is None:
|
||||
return None
|
||||
query_variable = config.get("dataset_query_variable")
|
||||
|
||||
metadata_model_config_dict = dataset_configs.get("metadata_model_config")
|
||||
metadata_filtering_conditions_dict = dataset_configs.get("metadata_filtering_conditions")
|
||||
|
||||
if dataset_configs["retrieval_model"] == "single":
|
||||
return DatasetEntity(
|
||||
dataset_ids=dataset_ids,
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
query_variable=query_variable,
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||
dataset_configs["retrieval_model"]
|
||||
),
|
||||
metadata_filtering_mode=cast(
|
||||
Literal["disabled", "automatic", "manual"],
|
||||
dataset_configs.get("metadata_filtering_mode", "disabled"),
|
||||
),
|
||||
metadata_model_config=ModelConfig(**metadata_model_config_dict)
|
||||
if isinstance(metadata_model_config_dict, dict)
|
||||
else None,
|
||||
metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict)
|
||||
if isinstance(metadata_filtering_conditions_dict, dict)
|
||||
else None,
|
||||
),
|
||||
)
|
||||
else:
|
||||
score_threshold_val = dataset_configs.get("score_threshold")
|
||||
reranking_model_val = dataset_configs.get("reranking_model")
|
||||
weights_val = dataset_configs.get("weights")
|
||||
|
||||
return DatasetEntity(
|
||||
dataset_ids=dataset_ids,
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
query_variable=query_variable,
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||
dataset_configs["retrieval_model"]
|
||||
),
|
||||
top_k=int(dataset_configs.get("top_k", 4)),
|
||||
score_threshold=float(score_threshold_val)
|
||||
if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None
|
||||
else None,
|
||||
reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None,
|
||||
weights=weights_val if isinstance(weights_val, dict) else None,
|
||||
reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)),
|
||||
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
|
||||
metadata_filtering_mode=cast(
|
||||
Literal["disabled", "automatic", "manual"],
|
||||
dataset_configs.get("metadata_filtering_mode", "disabled"),
|
||||
),
|
||||
metadata_model_config=ModelConfig(**metadata_model_config_dict)
|
||||
if isinstance(metadata_model_config_dict, dict)
|
||||
else None,
|
||||
metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict)
|
||||
if isinstance(metadata_filtering_conditions_dict, dict)
|
||||
else None,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for dataset feature
|
||||
|
||||
:param tenant_id: tenant ID
|
||||
:param app_mode: app mode
|
||||
:param config: app model config args
|
||||
"""
|
||||
# Extract dataset config for legacy compatibility
|
||||
config = cls.extract_dataset_config_for_legacy_compatibility(tenant_id, app_mode, config)
|
||||
|
||||
# dataset_configs
|
||||
if "dataset_configs" not in config or not config.get("dataset_configs"):
|
||||
config["dataset_configs"] = {}
|
||||
config["dataset_configs"]["retrieval_model"] = config["dataset_configs"].get("retrieval_model", "single")
|
||||
|
||||
if not isinstance(config["dataset_configs"], dict):
|
||||
raise ValueError("dataset_configs must be of object type")
|
||||
|
||||
if "datasets" not in config["dataset_configs"] or not config["dataset_configs"].get("datasets"):
|
||||
config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
|
||||
|
||||
need_manual_query_datasets = config.get("dataset_configs", {}).get("datasets", {}).get("datasets")
|
||||
|
||||
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
|
||||
# Only check when mode is completion
|
||||
dataset_query_variable = config.get("dataset_query_variable")
|
||||
|
||||
if not dataset_query_variable:
|
||||
raise ValueError("Dataset query variable is required when dataset is exist")
|
||||
|
||||
return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]
|
||||
|
||||
@classmethod
|
||||
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict):
|
||||
"""
|
||||
Extract dataset config for legacy compatibility
|
||||
|
||||
:param tenant_id: tenant ID
|
||||
:param app_mode: app mode
|
||||
:param config: app model config args
|
||||
"""
|
||||
# Extract dataset config for legacy compatibility
|
||||
if "agent_mode" not in config or not config.get("agent_mode"):
|
||||
config["agent_mode"] = {}
|
||||
|
||||
if not isinstance(config["agent_mode"], dict):
|
||||
raise ValueError("agent_mode must be of object type")
|
||||
|
||||
# enabled
|
||||
if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]:
|
||||
config["agent_mode"]["enabled"] = False
|
||||
|
||||
if not isinstance(config["agent_mode"]["enabled"], bool):
|
||||
raise ValueError("enabled in agent_mode must be of boolean type")
|
||||
|
||||
# tools
|
||||
if "tools" not in config["agent_mode"] or not config["agent_mode"].get("tools"):
|
||||
config["agent_mode"]["tools"] = []
|
||||
|
||||
if not isinstance(config["agent_mode"]["tools"], list):
|
||||
raise ValueError("tools in agent_mode must be a list of objects")
|
||||
|
||||
# strategy
|
||||
if "strategy" not in config["agent_mode"] or not config["agent_mode"].get("strategy"):
|
||||
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER
|
||||
|
||||
has_datasets = False
|
||||
if config.get("agent_mode", {}).get("strategy") in {
|
||||
PlanningStrategy.ROUTER,
|
||||
PlanningStrategy.REACT_ROUTER,
|
||||
}:
|
||||
for tool in config.get("agent_mode", {}).get("tools", []):
|
||||
key = list(tool.keys())[0]
|
||||
if key == "dataset":
|
||||
# old style, use tool name as key
|
||||
tool_item = tool[key]
|
||||
|
||||
if "enabled" not in tool_item or not tool_item["enabled"]:
|
||||
tool_item["enabled"] = False
|
||||
|
||||
if not isinstance(tool_item["enabled"], bool):
|
||||
raise ValueError("enabled in agent_mode.tools must be of boolean type")
|
||||
|
||||
if "id" not in tool_item:
|
||||
raise ValueError("id is required in dataset")
|
||||
|
||||
try:
|
||||
uuid.UUID(tool_item["id"])
|
||||
except ValueError:
|
||||
raise ValueError("id in dataset must be of UUID type")
|
||||
|
||||
if not cls.is_dataset_exists(tenant_id, tool_item["id"]):
|
||||
raise ValueError("Dataset ID does not exist, please check your permission.")
|
||||
|
||||
has_datasets = True
|
||||
|
||||
need_manual_query_datasets = has_datasets and config.get("agent_mode", {}).get("enabled")
|
||||
|
||||
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
|
||||
# Only check when mode is completion
|
||||
dataset_query_variable = config.get("dataset_query_variable")
|
||||
|
||||
if not dataset_query_variable:
|
||||
raise ValueError("Dataset query variable is required when dataset is exist")
|
||||
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def is_dataset_exists(cls, tenant_id: str, dataset_id: str) -> bool:
|
||||
# verify if the dataset ID exists
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
||||
if not dataset:
|
||||
return False
|
||||
|
||||
if dataset.tenant_id != tenant_id:
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -0,0 +1,91 @@
|
||||
from typing import cast
|
||||
|
||||
from core.app.app_config.entities import EasyUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.provider_manager import ProviderManager
|
||||
|
||||
|
||||
class ModelConfigConverter:
|
||||
@classmethod
|
||||
def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity:
|
||||
"""
|
||||
Convert app model config dict to entity.
|
||||
:param app_config: app config
|
||||
:raises ProviderTokenNotInitError: provider token not init error
|
||||
:return: app orchestration config entity
|
||||
"""
|
||||
model_config = app_config.model
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
provider_name = provider_model_bundle.configuration.provider.provider
|
||||
model_name = model_config.model
|
||||
|
||||
model_type_instance = provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
# check model credentials
|
||||
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
model_type=ModelType.LLM, model=model_config.model
|
||||
)
|
||||
|
||||
if model_credentials is None:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||
|
||||
# check model
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=model_config.model, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if provider_model is None:
|
||||
model_name = model_config.model
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
||||
if provider_model.status == ModelStatus.NO_CONFIGURE:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||
elif provider_model.status == ModelStatus.NO_PERMISSION:
|
||||
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
|
||||
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
|
||||
|
||||
# model config
|
||||
completion_params = model_config.parameters
|
||||
stop = []
|
||||
if "stop" in completion_params:
|
||||
stop = completion_params["stop"]
|
||||
del completion_params["stop"]
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
|
||||
|
||||
# get model mode
|
||||
model_mode = model_config.mode
|
||||
if not model_mode:
|
||||
model_mode = LLMMode.CHAT
|
||||
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
|
||||
try:
|
||||
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE])
|
||||
except ValueError:
|
||||
# Fall back to CHAT mode if the stored value is invalid
|
||||
model_mode = LLMMode.CHAT
|
||||
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
||||
return ModelConfigWithCredentialsEntity(
|
||||
provider=model_config.provider,
|
||||
model=model_config.model,
|
||||
model_schema=model_schema,
|
||||
mode=model_mode,
|
||||
provider_model_bundle=provider_model_bundle,
|
||||
credentials=model_credentials,
|
||||
parameters=completion_params,
|
||||
stop=stop,
|
||||
)
|
||||
@@ -0,0 +1,122 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.provider_manager import ProviderManager
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
|
||||
class ModelConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> ModelConfigEntity:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
:param config: model config args
|
||||
"""
|
||||
# model config
|
||||
model_config = config.get("model")
|
||||
|
||||
if not model_config:
|
||||
raise ValueError("model is required")
|
||||
|
||||
completion_params = model_config.get("completion_params")
|
||||
stop = []
|
||||
if "stop" in completion_params:
|
||||
stop = completion_params["stop"]
|
||||
del completion_params["stop"]
|
||||
|
||||
# get model mode
|
||||
model_mode = model_config.get("mode")
|
||||
|
||||
return ModelConfigEntity(
|
||||
provider=config["model"]["provider"],
|
||||
model=config["model"]["name"],
|
||||
mode=model_mode,
|
||||
parameters=completion_params,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for model config
|
||||
|
||||
:param tenant_id: tenant id
|
||||
:param config: app model config args
|
||||
"""
|
||||
if "model" not in config:
|
||||
raise ValueError("model is required")
|
||||
|
||||
if not isinstance(config["model"], dict):
|
||||
raise ValueError("model must be of object type")
|
||||
|
||||
# model.provider
|
||||
model_provider_factory = ModelProviderFactory(tenant_id)
|
||||
provider_entities = model_provider_factory.get_providers()
|
||||
model_provider_names = [provider.provider for provider in provider_entities]
|
||||
if "provider" not in config["model"]:
|
||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||
|
||||
if "/" not in config["model"]["provider"]:
|
||||
config["model"]["provider"] = str(ModelProviderID(config["model"]["provider"]))
|
||||
|
||||
if config["model"]["provider"] not in model_provider_names:
|
||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||
|
||||
# model.name
|
||||
if "name" not in config["model"]:
|
||||
raise ValueError("model.name is required")
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
models = provider_manager.get_configurations(tenant_id).get_models(
|
||||
provider=config["model"]["provider"], model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if not models:
|
||||
raise ValueError("model.name must be in the specified model list")
|
||||
|
||||
model_ids = [m.model for m in models]
|
||||
if config["model"]["name"] not in model_ids:
|
||||
raise ValueError("model.name must be in the specified model list")
|
||||
|
||||
model_mode = None
|
||||
for model in models:
|
||||
if model.model == config["model"]["name"]:
|
||||
model_mode = model.model_properties.get(ModelPropertyKey.MODE)
|
||||
break
|
||||
|
||||
# model.mode
|
||||
if model_mode:
|
||||
config["model"]["mode"] = model_mode
|
||||
else:
|
||||
config["model"]["mode"] = "completion"
|
||||
|
||||
# model.completion_params
|
||||
if "completion_params" not in config["model"]:
|
||||
raise ValueError("model.completion_params is required")
|
||||
|
||||
config["model"]["completion_params"] = cls.validate_model_completion_params(
|
||||
config["model"]["completion_params"]
|
||||
)
|
||||
|
||||
return dict(config), ["model"]
|
||||
|
||||
@classmethod
|
||||
def validate_model_completion_params(cls, cp: dict):
|
||||
# model.completion_params
|
||||
if not isinstance(cp, dict):
|
||||
raise ValueError("model.completion_params must be of object type")
|
||||
|
||||
# stop
|
||||
if "stop" not in cp:
|
||||
cp["stop"] = []
|
||||
elif not isinstance(cp["stop"], list):
|
||||
raise ValueError("stop in model.completion_params must be of list type")
|
||||
|
||||
if len(cp["stop"]) > 4:
|
||||
raise ValueError("stop sequences must be less than 4")
|
||||
|
||||
return cp
|
||||
@@ -0,0 +1,142 @@
|
||||
from core.app.app_config.entities import (
|
||||
AdvancedChatMessageEntity,
|
||||
AdvancedChatPromptTemplateEntity,
|
||||
AdvancedCompletionPromptTemplateEntity,
|
||||
PromptTemplateEntity,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class PromptTemplateConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> PromptTemplateEntity:
|
||||
if not config.get("prompt_type"):
|
||||
raise ValueError("prompt_type is required")
|
||||
|
||||
prompt_type = PromptTemplateEntity.PromptType.value_of(config["prompt_type"])
|
||||
if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
|
||||
simple_prompt_template = config.get("pre_prompt", "")
|
||||
return PromptTemplateEntity(prompt_type=prompt_type, simple_prompt_template=simple_prompt_template)
|
||||
else:
|
||||
advanced_chat_prompt_template = None
|
||||
chat_prompt_config = config.get("chat_prompt_config", {})
|
||||
if chat_prompt_config:
|
||||
chat_prompt_messages = []
|
||||
for message in chat_prompt_config.get("prompt", []):
|
||||
text = message.get("text")
|
||||
if not isinstance(text, str):
|
||||
raise ValueError("message text must be a string")
|
||||
role = message.get("role")
|
||||
if not isinstance(role, str):
|
||||
raise ValueError("message role must be a string")
|
||||
chat_prompt_messages.append(
|
||||
AdvancedChatMessageEntity(text=text, role=PromptMessageRole.value_of(role))
|
||||
)
|
||||
|
||||
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
|
||||
|
||||
advanced_completion_prompt_template = None
|
||||
completion_prompt_config = config.get("completion_prompt_config", {})
|
||||
if completion_prompt_config:
|
||||
completion_prompt_template_params = {
|
||||
"prompt": completion_prompt_config["prompt"]["text"],
|
||||
}
|
||||
|
||||
if "conversation_histories_role" in completion_prompt_config:
|
||||
completion_prompt_template_params["role_prefix"] = {
|
||||
"user": completion_prompt_config["conversation_histories_role"]["user_prefix"],
|
||||
"assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"],
|
||||
}
|
||||
|
||||
advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
|
||||
**completion_prompt_template_params
|
||||
)
|
||||
|
||||
return PromptTemplateEntity(
|
||||
prompt_type=prompt_type,
|
||||
advanced_chat_prompt_template=advanced_chat_prompt_template,
|
||||
advanced_completion_prompt_template=advanced_completion_prompt_template,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate pre_prompt and set defaults for prompt feature
|
||||
depending on the config['model']
|
||||
|
||||
:param app_mode: app mode
|
||||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("prompt_type"):
|
||||
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE
|
||||
|
||||
prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
|
||||
if config["prompt_type"] not in prompt_type_vals:
|
||||
raise ValueError(f"prompt_type must be in {prompt_type_vals}")
|
||||
|
||||
# chat_prompt_config
|
||||
if not config.get("chat_prompt_config"):
|
||||
config["chat_prompt_config"] = {}
|
||||
|
||||
if not isinstance(config["chat_prompt_config"], dict):
|
||||
raise ValueError("chat_prompt_config must be of object type")
|
||||
|
||||
# completion_prompt_config
|
||||
if not config.get("completion_prompt_config"):
|
||||
config["completion_prompt_config"] = {}
|
||||
|
||||
if not isinstance(config["completion_prompt_config"], dict):
|
||||
raise ValueError("completion_prompt_config must be of object type")
|
||||
|
||||
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED:
|
||||
if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
|
||||
raise ValueError(
|
||||
"chat_prompt_config or completion_prompt_config is required when prompt_type is advanced"
|
||||
)
|
||||
|
||||
model_mode_vals = [mode.value for mode in ModelMode]
|
||||
if config["model"]["mode"] not in model_mode_vals:
|
||||
raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")
|
||||
|
||||
if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION:
|
||||
user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"]
|
||||
assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"]
|
||||
|
||||
if not user_prefix:
|
||||
config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] = "Human"
|
||||
|
||||
if not assistant_prefix:
|
||||
config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant"
|
||||
|
||||
if config["model"]["mode"] == ModelMode.CHAT:
|
||||
prompt_list = config["chat_prompt_config"]["prompt"]
|
||||
|
||||
if len(prompt_list) > 10:
|
||||
raise ValueError("prompt messages must be less than 10")
|
||||
else:
|
||||
# pre_prompt, for simple mode
|
||||
if not config.get("pre_prompt"):
|
||||
config["pre_prompt"] = ""
|
||||
|
||||
if not isinstance(config["pre_prompt"], str):
|
||||
raise ValueError("pre_prompt must be of string type")
|
||||
|
||||
return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]
|
||||
|
||||
@classmethod
|
||||
def validate_post_prompt_and_set_defaults(cls, config: dict):
|
||||
"""
|
||||
Validate post_prompt and set defaults for prompt feature
|
||||
|
||||
:param config: app model config args
|
||||
"""
|
||||
# post_prompt
|
||||
if not config.get("post_prompt"):
|
||||
config["post_prompt"] = ""
|
||||
|
||||
if not isinstance(config["post_prompt"], str):
|
||||
raise ValueError("post_prompt must be of string type")
|
||||
|
||||
return config
|
||||
@@ -0,0 +1,189 @@
|
||||
import re
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
|
||||
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
|
||||
[
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
VariableEntityType.SELECT,
|
||||
VariableEntityType.PARAGRAPH,
|
||||
VariableEntityType.NUMBER,
|
||||
VariableEntityType.EXTERNAL_DATA_TOOL,
|
||||
VariableEntityType.CHECKBOX,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class BasicVariablesConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
:param config: model config args
|
||||
"""
|
||||
external_data_variables = []
|
||||
variable_entities = []
|
||||
|
||||
# old external_data_tools
|
||||
external_data_tools = config.get("external_data_tools", [])
|
||||
for external_data_tool in external_data_tools:
|
||||
if "enabled" not in external_data_tool or not external_data_tool["enabled"]:
|
||||
continue
|
||||
|
||||
external_data_variables.append(
|
||||
ExternalDataVariableEntity(
|
||||
variable=external_data_tool["variable"],
|
||||
type=external_data_tool["type"],
|
||||
config=external_data_tool["config"],
|
||||
)
|
||||
)
|
||||
|
||||
# variables and external_data_tools
|
||||
for variables in config.get("user_input_form", []):
|
||||
variable_type = list(variables.keys())[0]
|
||||
if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
|
||||
variable = variables[variable_type]
|
||||
if "config" not in variable:
|
||||
continue
|
||||
|
||||
external_data_variables.append(
|
||||
ExternalDataVariableEntity(
|
||||
variable=variable["variable"], type=variable["type"], config=variable["config"]
|
||||
)
|
||||
)
|
||||
elif variable_type in {
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
VariableEntityType.PARAGRAPH,
|
||||
VariableEntityType.NUMBER,
|
||||
VariableEntityType.SELECT,
|
||||
VariableEntityType.CHECKBOX,
|
||||
}:
|
||||
variable = variables[variable_type]
|
||||
variable_entities.append(
|
||||
VariableEntity(
|
||||
type=variable_type,
|
||||
variable=variable.get("variable"),
|
||||
description=variable.get("description") or "",
|
||||
label=variable.get("label"),
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options") or [],
|
||||
)
|
||||
)
|
||||
|
||||
return variable_entities, external_data_variables
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for user input form
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param config: app model config args
|
||||
"""
|
||||
related_config_keys = []
|
||||
config, current_related_config_keys = cls.validate_variables_and_set_defaults(config)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
config, current_related_config_keys = cls.validate_external_data_tools_and_set_defaults(tenant_id, config)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
return config, related_config_keys
|
||||
|
||||
@classmethod
|
||||
def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for user input form
|
||||
|
||||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("user_input_form"):
|
||||
config["user_input_form"] = []
|
||||
|
||||
if not isinstance(config["user_input_form"], list):
|
||||
raise ValueError("user_input_form must be a list of objects")
|
||||
|
||||
variables = []
|
||||
for item in config["user_input_form"]:
|
||||
key = list(item.keys())[0]
|
||||
# if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
|
||||
if key not in {
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
VariableEntityType.SELECT,
|
||||
VariableEntityType.PARAGRAPH,
|
||||
VariableEntityType.NUMBER,
|
||||
VariableEntityType.EXTERNAL_DATA_TOOL,
|
||||
VariableEntityType.CHECKBOX,
|
||||
}:
|
||||
allowed_keys = ", ".join(i.value for i in _ALLOWED_VARIABLE_ENTITY_TYPE)
|
||||
raise ValueError(f"Keys in user_input_form list can only be {allowed_keys}")
|
||||
|
||||
form_item = item[key]
|
||||
if "label" not in form_item:
|
||||
raise ValueError("label is required in user_input_form")
|
||||
|
||||
if not isinstance(form_item["label"], str):
|
||||
raise ValueError("label in user_input_form must be of string type")
|
||||
|
||||
if "variable" not in form_item:
|
||||
raise ValueError("variable is required in user_input_form")
|
||||
|
||||
if not isinstance(form_item["variable"], str):
|
||||
raise ValueError("variable in user_input_form must be of string type")
|
||||
|
||||
pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
|
||||
if pattern.match(form_item["variable"]) is None:
|
||||
raise ValueError("variable in user_input_form must be a string, and cannot start with a number")
|
||||
|
||||
variables.append(form_item["variable"])
|
||||
|
||||
if "required" not in form_item or not form_item["required"]:
|
||||
form_item["required"] = False
|
||||
|
||||
if not isinstance(form_item["required"], bool):
|
||||
raise ValueError("required in user_input_form must be of boolean type")
|
||||
|
||||
if key == "select":
|
||||
if "options" not in form_item or not form_item["options"]:
|
||||
form_item["options"] = []
|
||||
|
||||
if not isinstance(form_item["options"], list):
|
||||
raise ValueError("options in user_input_form must be a list of strings")
|
||||
|
||||
if "default" in form_item and form_item["default"] and form_item["default"] not in form_item["options"]:
|
||||
raise ValueError("default value in user_input_form must be in the options list")
|
||||
|
||||
return config, ["user_input_form"]
|
||||
|
||||
@classmethod
|
||||
def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for external data fetch feature
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("external_data_tools"):
|
||||
config["external_data_tools"] = []
|
||||
|
||||
if not isinstance(config["external_data_tools"], list):
|
||||
raise ValueError("external_data_tools must be of list type")
|
||||
|
||||
for tool in config["external_data_tools"]:
|
||||
if "enabled" not in tool or not tool["enabled"]:
|
||||
tool["enabled"] = False
|
||||
|
||||
if not tool["enabled"]:
|
||||
continue
|
||||
|
||||
if "type" not in tool or not tool["type"]:
|
||||
raise ValueError("external_data_tools[].type is required")
|
||||
|
||||
typ = tool["type"]
|
||||
config = tool["config"]
|
||||
|
||||
ExternalDataToolFactory.validate_config(name=typ, tenant_id=tenant_id, config=config)
|
||||
|
||||
return config, ["external_data_tools"]
|
||||
336
dify/api/core/app/app_config/entities.py
Normal file
336
dify/api/core/app/app_config/entities.py
Normal file
@@ -0,0 +1,336 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.file import FileTransferMethod, FileType, FileUploadConfig
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class ModelConfigEntity(BaseModel):
|
||||
"""
|
||||
Model Config Entity.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
model: str
|
||||
mode: str | None = None
|
||||
parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
stop: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AdvancedChatMessageEntity(BaseModel):
|
||||
"""
|
||||
Advanced Chat Message Entity.
|
||||
"""
|
||||
|
||||
text: str
|
||||
role: PromptMessageRole
|
||||
|
||||
|
||||
class AdvancedChatPromptTemplateEntity(BaseModel):
|
||||
"""
|
||||
Advanced Chat Prompt Template Entity.
|
||||
"""
|
||||
|
||||
messages: list[AdvancedChatMessageEntity]
|
||||
|
||||
|
||||
class AdvancedCompletionPromptTemplateEntity(BaseModel):
|
||||
"""
|
||||
Advanced Completion Prompt Template Entity.
|
||||
"""
|
||||
|
||||
class RolePrefixEntity(BaseModel):
|
||||
"""
|
||||
Role Prefix Entity.
|
||||
"""
|
||||
|
||||
user: str
|
||||
assistant: str
|
||||
|
||||
prompt: str
|
||||
role_prefix: RolePrefixEntity | None = None
|
||||
|
||||
|
||||
class PromptTemplateEntity(BaseModel):
|
||||
"""
|
||||
Prompt Template Entity.
|
||||
"""
|
||||
|
||||
class PromptType(StrEnum):
|
||||
"""
|
||||
Prompt Type.
|
||||
'simple', 'advanced'
|
||||
"""
|
||||
|
||||
SIMPLE = auto()
|
||||
ADVANCED = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str):
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid prompt type value {value}")
|
||||
|
||||
prompt_type: PromptType
|
||||
simple_prompt_template: str | None = None
|
||||
advanced_chat_prompt_template: AdvancedChatPromptTemplateEntity | None = None
|
||||
advanced_completion_prompt_template: AdvancedCompletionPromptTemplateEntity | None = None
|
||||
|
||||
|
||||
class VariableEntityType(StrEnum):
|
||||
TEXT_INPUT = "text-input"
|
||||
SELECT = "select"
|
||||
PARAGRAPH = "paragraph"
|
||||
NUMBER = "number"
|
||||
EXTERNAL_DATA_TOOL = "external_data_tool"
|
||||
FILE = "file"
|
||||
FILE_LIST = "file-list"
|
||||
CHECKBOX = "checkbox"
|
||||
|
||||
|
||||
class VariableEntity(BaseModel):
|
||||
"""
|
||||
Variable Entity.
|
||||
"""
|
||||
|
||||
# `variable` records the name of the variable in user inputs.
|
||||
variable: str
|
||||
label: str
|
||||
description: str = ""
|
||||
type: VariableEntityType
|
||||
required: bool = False
|
||||
hide: bool = False
|
||||
default: Any = None
|
||||
max_length: int | None = None
|
||||
options: Sequence[str] = Field(default_factory=list)
|
||||
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
|
||||
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
|
||||
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
|
||||
|
||||
@field_validator("description", mode="before")
|
||||
@classmethod
|
||||
def convert_none_description(cls, v: Any) -> str:
|
||||
return v or ""
|
||||
|
||||
@field_validator("options", mode="before")
|
||||
@classmethod
|
||||
def convert_none_options(cls, v: Any) -> Sequence[str]:
|
||||
return v or []
|
||||
|
||||
|
||||
class RagPipelineVariableEntity(VariableEntity):
|
||||
"""
|
||||
Rag Pipeline Variable Entity.
|
||||
"""
|
||||
|
||||
tooltips: str | None = None
|
||||
placeholder: str | None = None
|
||||
belong_to_node_id: str
|
||||
|
||||
|
||||
class ExternalDataVariableEntity(BaseModel):
|
||||
"""
|
||||
External Data Variable Entity.
|
||||
"""
|
||||
|
||||
variable: str
|
||||
type: str
|
||||
config: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
SupportedComparisonOperator = Literal[
|
||||
# for string or array
|
||||
"contains",
|
||||
"not contains",
|
||||
"start with",
|
||||
"end with",
|
||||
"is",
|
||||
"is not",
|
||||
"empty",
|
||||
"not empty",
|
||||
"in",
|
||||
"not in",
|
||||
# for number
|
||||
"=",
|
||||
"≠",
|
||||
">",
|
||||
"<",
|
||||
"≥",
|
||||
"≤",
|
||||
# for time
|
||||
"before",
|
||||
"after",
|
||||
]
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
provider: str
|
||||
name: str
|
||||
mode: LLMMode
|
||||
completion_params: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class Condition(BaseModel):
|
||||
"""
|
||||
Condition detail
|
||||
"""
|
||||
|
||||
name: str
|
||||
comparison_operator: SupportedComparisonOperator
|
||||
value: str | Sequence[str] | None | int | float = None
|
||||
|
||||
|
||||
class MetadataFilteringCondition(BaseModel):
|
||||
"""
|
||||
Metadata Filtering Condition.
|
||||
"""
|
||||
|
||||
logical_operator: Literal["and", "or"] | None = "and"
|
||||
conditions: list[Condition] | None = Field(default=None, deprecated=True)
|
||||
|
||||
|
||||
class DatasetRetrieveConfigEntity(BaseModel):
|
||||
"""
|
||||
Dataset Retrieve Config Entity.
|
||||
"""
|
||||
|
||||
class RetrieveStrategy(StrEnum):
|
||||
"""
|
||||
Dataset Retrieve Strategy.
|
||||
'single' or 'multiple'
|
||||
"""
|
||||
|
||||
SINGLE = auto()
|
||||
MULTIPLE = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str):
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid retrieve strategy value {value}")
|
||||
|
||||
query_variable: str | None = None # Only when app mode is completion
|
||||
|
||||
retrieve_strategy: RetrieveStrategy
|
||||
top_k: int | None = None
|
||||
score_threshold: float | None = 0.0
|
||||
rerank_mode: str | None = "reranking_model"
|
||||
reranking_model: dict | None = None
|
||||
weights: dict | None = None
|
||||
reranking_enabled: bool | None = True
|
||||
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
|
||||
metadata_model_config: ModelConfig | None = None
|
||||
metadata_filtering_conditions: MetadataFilteringCondition | None = None
|
||||
|
||||
|
||||
class DatasetEntity(BaseModel):
|
||||
"""
|
||||
Dataset Config Entity.
|
||||
"""
|
||||
|
||||
dataset_ids: list[str]
|
||||
retrieve_config: DatasetRetrieveConfigEntity
|
||||
|
||||
|
||||
class SensitiveWordAvoidanceEntity(BaseModel):
|
||||
"""
|
||||
Sensitive Word Avoidance Entity.
|
||||
"""
|
||||
|
||||
type: str
|
||||
config: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class TextToSpeechEntity(BaseModel):
|
||||
"""
|
||||
Sensitive Word Avoidance Entity.
|
||||
"""
|
||||
|
||||
enabled: bool
|
||||
voice: str | None = None
|
||||
language: str | None = None
|
||||
|
||||
|
||||
class TracingConfigEntity(BaseModel):
|
||||
"""
|
||||
Tracing Config Entity.
|
||||
"""
|
||||
|
||||
enabled: bool
|
||||
tracing_provider: str
|
||||
|
||||
|
||||
class AppAdditionalFeatures(BaseModel):
|
||||
file_upload: FileUploadConfig | None = None
|
||||
opening_statement: str | None = None
|
||||
suggested_questions: list[str] = []
|
||||
suggested_questions_after_answer: bool = False
|
||||
show_retrieve_source: bool = False
|
||||
more_like_this: bool = False
|
||||
speech_to_text: bool = False
|
||||
text_to_speech: TextToSpeechEntity | None = None
|
||||
trace_config: TracingConfigEntity | None = None
|
||||
|
||||
|
||||
class AppConfig(BaseModel):
|
||||
"""
|
||||
Application Config Entity.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
app_mode: AppMode
|
||||
additional_features: AppAdditionalFeatures | None = None
|
||||
variables: list[VariableEntity] = []
|
||||
sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None
|
||||
|
||||
|
||||
class EasyUIBasedAppModelConfigFrom(StrEnum):
|
||||
"""
|
||||
App Model Config From.
|
||||
"""
|
||||
|
||||
ARGS = auto()
|
||||
APP_LATEST_CONFIG = "app-latest-config"
|
||||
CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"
|
||||
|
||||
|
||||
class EasyUIBasedAppConfig(AppConfig):
|
||||
"""
|
||||
Easy UI Based App Config Entity.
|
||||
"""
|
||||
|
||||
app_model_config_from: EasyUIBasedAppModelConfigFrom
|
||||
app_model_config_id: str
|
||||
app_model_config_dict: dict
|
||||
model: ModelConfigEntity
|
||||
prompt_template: PromptTemplateEntity
|
||||
dataset: DatasetEntity | None = None
|
||||
external_data_variables: list[ExternalDataVariableEntity] = []
|
||||
|
||||
|
||||
class WorkflowUIBasedAppConfig(AppConfig):
|
||||
"""
|
||||
Workflow UI Based App Config Entity.
|
||||
"""
|
||||
|
||||
workflow_id: str
|
||||
0
dify/api/core/app/app_config/features/__init__.py
Normal file
0
dify/api/core/app/app_config/features/__init__.py
Normal file
43
dify/api/core/app/app_config/features/file_upload/manager.py
Normal file
43
dify/api/core/app/app_config/features/file_upload/manager.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||
from core.file import FileUploadConfig
|
||||
|
||||
|
||||
class FileUploadConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: Mapping[str, Any], is_vision: bool = True):
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
:param config: model config args
|
||||
:param is_vision: if True, the feature is vision feature
|
||||
"""
|
||||
file_upload_dict = config.get("file_upload")
|
||||
if file_upload_dict:
|
||||
if file_upload_dict.get("enabled"):
|
||||
transform_methods = file_upload_dict.get("allowed_file_upload_methods", [])
|
||||
file_upload_dict["image_config"] = {
|
||||
"number_limits": file_upload_dict.get("number_limits", DEFAULT_FILE_NUMBER_LIMITS),
|
||||
"transfer_methods": transform_methods,
|
||||
}
|
||||
|
||||
if is_vision:
|
||||
file_upload_dict["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "high")
|
||||
|
||||
return FileUploadConfig.model_validate(file_upload_dict)
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for file upload feature
|
||||
|
||||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("file_upload"):
|
||||
config["file_upload"] = {}
|
||||
else:
|
||||
FileUploadConfig.model_validate(config["file_upload"])
|
||||
|
||||
return config, ["file_upload"]
|
||||
@@ -0,0 +1,32 @@
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
|
||||
|
||||
class MoreLikeThisConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class AppConfigModel(BaseModel):
|
||||
more_like_this: MoreLikeThisConfig = Field(default_factory=MoreLikeThisConfig)
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class MoreLikeThisConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> bool:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
:param config: model config args
|
||||
"""
|
||||
validated_config, _ = cls.validate_and_set_defaults(config)
|
||||
return AppConfigModel.model_validate(validated_config).more_like_this.enabled
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
try:
|
||||
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
|
||||
except ValidationError:
|
||||
raise ValueError(
|
||||
"more_like_this must be of dict type and enabled in more_like_this must be of boolean type"
|
||||
)
|
||||
@@ -0,0 +1,41 @@
|
||||
class OpeningStatementConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> tuple[str, list]:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
:param config: model config args
|
||||
"""
|
||||
# opening statement
|
||||
opening_statement = config.get("opening_statement", "")
|
||||
|
||||
# suggested questions
|
||||
suggested_questions_list = config.get("suggested_questions", [])
|
||||
|
||||
return opening_statement, suggested_questions_list
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for opening statement feature
|
||||
|
||||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("opening_statement"):
|
||||
config["opening_statement"] = ""
|
||||
|
||||
if not isinstance(config["opening_statement"], str):
|
||||
raise ValueError("opening_statement must be of string type")
|
||||
|
||||
# suggested_questions
|
||||
if not config.get("suggested_questions"):
|
||||
config["suggested_questions"] = []
|
||||
|
||||
if not isinstance(config["suggested_questions"], list):
|
||||
raise ValueError("suggested_questions must be of list type")
|
||||
|
||||
for question in config["suggested_questions"]:
|
||||
if not isinstance(question, str):
|
||||
raise ValueError("Elements in suggested_questions list must be of string type")
|
||||
|
||||
return config, ["opening_statement", "suggested_questions"]
|
||||
@@ -0,0 +1,31 @@
|
||||
class RetrievalResourceConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> bool:
|
||||
show_retrieve_source = False
|
||||
retriever_resource_dict = config.get("retriever_resource")
|
||||
if retriever_resource_dict:
|
||||
if retriever_resource_dict.get("enabled"):
|
||||
show_retrieve_source = True
|
||||
|
||||
return show_retrieve_source
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for retriever resource feature
|
||||
|
||||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("retriever_resource"):
|
||||
config["retriever_resource"] = {"enabled": False}
|
||||
|
||||
if not isinstance(config["retriever_resource"], dict):
|
||||
raise ValueError("retriever_resource must be of dict type")
|
||||
|
||||
if "enabled" not in config["retriever_resource"] or not config["retriever_resource"]["enabled"]:
|
||||
config["retriever_resource"]["enabled"] = False
|
||||
|
||||
if not isinstance(config["retriever_resource"]["enabled"], bool):
|
||||
raise ValueError("enabled in retriever_resource must be of boolean type")
|
||||
|
||||
return config, ["retriever_resource"]
|
||||
@@ -0,0 +1,36 @@
|
||||
class SpeechToTextConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> bool:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
:param config: model config args
|
||||
"""
|
||||
speech_to_text = False
|
||||
speech_to_text_dict = config.get("speech_to_text")
|
||||
if speech_to_text_dict:
|
||||
if speech_to_text_dict.get("enabled"):
|
||||
speech_to_text = True
|
||||
|
||||
return speech_to_text
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for speech to text feature
|
||||
|
||||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("speech_to_text"):
|
||||
config["speech_to_text"] = {"enabled": False}
|
||||
|
||||
if not isinstance(config["speech_to_text"], dict):
|
||||
raise ValueError("speech_to_text must be of dict type")
|
||||
|
||||
if "enabled" not in config["speech_to_text"] or not config["speech_to_text"]["enabled"]:
|
||||
config["speech_to_text"]["enabled"] = False
|
||||
|
||||
if not isinstance(config["speech_to_text"]["enabled"], bool):
|
||||
raise ValueError("enabled in speech_to_text must be of boolean type")
|
||||
|
||||
return config, ["speech_to_text"]
|
||||
@@ -0,0 +1,39 @@
|
||||
class SuggestedQuestionsAfterAnswerConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> bool:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
:param config: model config args
|
||||
"""
|
||||
suggested_questions_after_answer = False
|
||||
suggested_questions_after_answer_dict = config.get("suggested_questions_after_answer")
|
||||
if suggested_questions_after_answer_dict:
|
||||
if suggested_questions_after_answer_dict.get("enabled"):
|
||||
suggested_questions_after_answer = True
|
||||
|
||||
return suggested_questions_after_answer
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for suggested questions feature
|
||||
|
||||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("suggested_questions_after_answer"):
|
||||
config["suggested_questions_after_answer"] = {"enabled": False}
|
||||
|
||||
if not isinstance(config["suggested_questions_after_answer"], dict):
|
||||
raise ValueError("suggested_questions_after_answer must be of dict type")
|
||||
|
||||
if (
|
||||
"enabled" not in config["suggested_questions_after_answer"]
|
||||
or not config["suggested_questions_after_answer"]["enabled"]
|
||||
):
|
||||
config["suggested_questions_after_answer"]["enabled"] = False
|
||||
|
||||
if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):
|
||||
raise ValueError("enabled in suggested_questions_after_answer must be of boolean type")
|
||||
|
||||
return config, ["suggested_questions_after_answer"]
|
||||
@@ -0,0 +1,45 @@
|
||||
from core.app.app_config.entities import TextToSpeechEntity
|
||||
|
||||
|
||||
class TextToSpeechConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict):
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
:param config: model config args
|
||||
"""
|
||||
text_to_speech = None
|
||||
text_to_speech_dict = config.get("text_to_speech")
|
||||
if text_to_speech_dict:
|
||||
if text_to_speech_dict.get("enabled"):
|
||||
text_to_speech = TextToSpeechEntity(
|
||||
enabled=text_to_speech_dict.get("enabled"),
|
||||
voice=text_to_speech_dict.get("voice"),
|
||||
language=text_to_speech_dict.get("language"),
|
||||
)
|
||||
|
||||
return text_to_speech
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for text to speech feature
|
||||
|
||||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("text_to_speech"):
|
||||
config["text_to_speech"] = {"enabled": False, "voice": "", "language": ""}
|
||||
|
||||
if not isinstance(config["text_to_speech"], dict):
|
||||
raise ValueError("text_to_speech must be of dict type")
|
||||
|
||||
if "enabled" not in config["text_to_speech"] or not config["text_to_speech"]["enabled"]:
|
||||
config["text_to_speech"]["enabled"] = False
|
||||
config["text_to_speech"]["voice"] = ""
|
||||
config["text_to_speech"]["language"] = ""
|
||||
|
||||
if not isinstance(config["text_to_speech"]["enabled"], bool):
|
||||
raise ValueError("enabled in text_to_speech must be of boolean type")
|
||||
|
||||
return config, ["text_to_speech"]
|
||||
@@ -0,0 +1,69 @@
|
||||
import re
|
||||
|
||||
from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class WorkflowVariablesConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, workflow: Workflow) -> list[VariableEntity]:
|
||||
"""
|
||||
Convert workflow start variables to variables
|
||||
|
||||
:param workflow: workflow instance
|
||||
"""
|
||||
variables = []
|
||||
|
||||
# find start node
|
||||
user_input_form = workflow.user_input_form()
|
||||
|
||||
# variables
|
||||
for variable in user_input_form:
|
||||
variables.append(VariableEntity.model_validate(variable))
|
||||
|
||||
return variables
|
||||
|
||||
@classmethod
|
||||
def convert_rag_pipeline_variable(cls, workflow: Workflow, start_node_id: str) -> list[RagPipelineVariableEntity]:
|
||||
"""
|
||||
Convert workflow start variables to variables
|
||||
|
||||
:param workflow: workflow instance
|
||||
"""
|
||||
variables = []
|
||||
|
||||
# get second step node
|
||||
rag_pipeline_variables = workflow.rag_pipeline_variables
|
||||
if not rag_pipeline_variables:
|
||||
return []
|
||||
variables_map = {item["variable"]: item for item in rag_pipeline_variables}
|
||||
|
||||
# get datasource node data
|
||||
datasource_node_data = None
|
||||
datasource_nodes = workflow.graph_dict.get("nodes", [])
|
||||
for datasource_node in datasource_nodes:
|
||||
if datasource_node.get("id") == start_node_id:
|
||||
datasource_node_data = datasource_node.get("data", {})
|
||||
break
|
||||
if datasource_node_data:
|
||||
datasource_parameters = datasource_node_data.get("datasource_parameters", {})
|
||||
|
||||
for _, value in datasource_parameters.items():
|
||||
if value.get("value") and isinstance(value.get("value"), str):
|
||||
pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
|
||||
match = re.match(pattern, value["value"])
|
||||
if match:
|
||||
full_path = match.group(1)
|
||||
last_part = full_path.split(".")[-1]
|
||||
variables_map.pop(last_part, None)
|
||||
if value.get("value") and isinstance(value.get("value"), list):
|
||||
last_part = value.get("value")[-1]
|
||||
variables_map.pop(last_part, None)
|
||||
|
||||
all_second_step_variables = list(variables_map.values())
|
||||
|
||||
for item in all_second_step_variables:
|
||||
if item.get("belong_to_node_id") == start_node_id or item.get("belong_to_node_id") == "shared":
|
||||
variables.append(RagPipelineVariableEntity.model_validate(item))
|
||||
|
||||
return variables
|
||||
Reference in New Issue
Block a user