2025-12-01 17:21:38 +08:00
parent 32fee2b8ab
commit fab8c13cb3
7511 changed files with 996300 additions and 0 deletions


@@ -0,0 +1,158 @@
import contextlib
from copy import deepcopy
from typing import Any
from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolParameter,
ToolProviderType,
)
class ToolParameterConfigurationManager:
"""
Tool parameter configuration manager
"""
tenant_id: str
tool_runtime: Tool
provider_name: str
provider_type: ToolProviderType
identity_id: str
def __init__(
self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
):
self.tenant_id = tenant_id
self.tool_runtime = tool_runtime
self.provider_name = provider_name
self.provider_type = provider_type
self.identity_id = identity_id
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
deep copy parameters
"""
return deepcopy(parameters)
def _merge_parameters(self) -> list[ToolParameter]:
"""
merge parameters
"""
# get tool parameters
tool_parameters = self.tool_runtime.entity.parameters or []
# get tool runtime parameters
runtime_parameters = self.tool_runtime.get_runtime_parameters()
# override parameters
current_parameters = tool_parameters.copy()
for runtime_parameter in runtime_parameters:
found = False
for index, parameter in enumerate(current_parameters):
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
current_parameters[index] = runtime_parameter
found = True
break
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)
return current_parameters
def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
mask tool parameters
return a deep copy of parameters with masked values
"""
parameters = self._deep_copy(parameters)
# override parameters
current_parameters = self._merge_parameters()
for parameter in current_parameters:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
if len(parameters[parameter.name]) > 6:
parameters[parameter.name] = (
parameters[parameter.name][:2]
+ "*" * (len(parameters[parameter.name]) - 4)
+ parameters[parameter.name][-2:]
)
else:
parameters[parameter.name] = "*" * len(parameters[parameter.name])
return parameters
def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
encrypt tool parameters with tenant id
return a deep copy of parameters with encrypted values
"""
# override parameters
current_parameters = self._merge_parameters()
parameters = self._deep_copy(parameters)
for parameter in current_parameters:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
parameters[parameter.name] = encrypted
return parameters
def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
decrypt tool parameters with tenant id
        return the parameters with decrypted values (values are decrypted in place and the result is cached)
"""
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f"{self.provider_type.value}.{self.provider_name}",
tool_name=self.tool_runtime.entity.identity.name,
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id,
)
cached_parameters = cache.get()
if cached_parameters:
return cached_parameters
# override parameters
current_parameters = self._merge_parameters()
has_secret_input = False
for parameter in current_parameters:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
has_secret_input = True
with contextlib.suppress(Exception):
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
if has_secret_input:
cache.set(parameters)
return parameters
def delete_tool_parameters_cache(self):
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f"{self.provider_type.value}.{self.provider_name}",
tool_name=self.tool_runtime.entity.identity.name,
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id,
)
cache.delete()
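
The masking rule above keeps the first and last two characters of a secret only when it is longer than six characters; shorter values are fully starred. A minimal standalone sketch of that rule (the helper name is illustrative):

def mask_secret(value: str) -> str:
    """Mask a secret, keeping the first/last two characters when long enough."""
    if len(value) > 6:
        return value[:2] + "*" * (len(value) - 4) + value[-2:]
    return "*" * len(value)

assert mask_secret("sk-abcdef123") == "sk********23"
assert mask_secret("abc") == "***"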


@@ -0,0 +1,200 @@
import threading
from typing import Any
from flask import Flask, current_app
from pydantic import BaseModel, Field
from sqlalchemy import select
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}
class DatasetMultiRetrieverToolInput(BaseModel):
query: str = Field(..., description="dataset multi retriever and rerank")
class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
"""Tool for querying multi dataset."""
name: str = "dataset_"
args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
description: str = "dataset multi retriever and rerank. "
dataset_ids: list[str]
reranking_provider_name: str
reranking_model_name: str
@classmethod
def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
return cls(
name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs
)
def _run(self, query: str) -> str:
threads = []
all_documents: list[RagDocument] = []
for dataset_id in self.dataset_ids:
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset_id,
"query": query,
"all_documents": all_documents,
"hit_callbacks": self.hit_callbacks,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
# do rerank for searched documents
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id,
provider=self.reranking_provider_name,
model_type=ModelType.RERANK,
model=self.reranking_model_name,
)
rerank_runner = RerankModelRunner(rerank_model_instance)
all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k)
for hit_callback in self.hit_callbacks:
hit_callback.on_tool_end(all_documents)
document_score_list = {}
for item in all_documents:
if item.metadata and item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
)
segments = db.session.scalars(document_segment_stmt).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments:
if segment.answer:
document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
else:
document_context_list.append(segment.get_sign_content())
if self.return_resource:
context_list: list[RetrievalSourceMetadata] = []
resource_number = 1
for segment in sorted_segments:
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document_stmt = select(Document).where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
document = db.session.scalar(document_stmt)
if dataset and document:
source = RetrievalSourceMetadata(
position=resource_number,
dataset_id=dataset.id,
dataset_name=dataset.name,
document_id=document.id,
document_name=document.name,
data_source_type=document.data_source_type,
segment_id=segment.id,
retriever_from=self.retriever_from,
score=document_score_list.get(segment.index_node_id),
doc_metadata=document.doc_metadata,
)
if self.retriever_from == "dev":
source.hit_count = segment.hit_count
source.word_count = segment.word_count
source.segment_position = segment.position
source.index_node_hash = segment.index_node_hash
if segment.answer:
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source.content = segment.content
context_list.append(source)
resource_number += 1
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list))
return ""
def _retriever(
self,
flask_app: Flask,
dataset_id: str,
query: str,
all_documents: list,
hit_callbacks: list[DatasetIndexToolCallbackHandler],
):
with flask_app.app_context():
stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
return []
for hit_callback in hit_callbacks:
hit_callback.on_query(query, dataset.id)
            # get the retrieval model; fall back to the default if none is configured
retrieval_model = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 4,
)
if documents:
all_documents.extend(documents)
else:
if self.top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 4,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
)
all_documents.extend(documents)
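
The fan-out above starts one thread per dataset, each appending into a shared list, joins them all, then runs a single rerank pass over the merged candidates. A reduced sketch of the same pattern (fake corpus and scores, no Flask/DB):

import threading

FAKE_CORPUS = {
    "ds-1": [("doc-a", 0.91), ("doc-b", 0.40)],
    "ds-2": [("doc-c", 0.75)],
}

def retrieve(dataset_id: str, query: str, all_documents: list) -> None:
    # extending a list is effectively atomic under the GIL, so this sketch skips locking
    all_documents.extend(FAKE_CORPUS.get(dataset_id, []))

all_documents: list[tuple[str, float]] = []
threads = [
    threading.Thread(target=retrieve, args=(ds, "query", all_documents))
    for ds in FAKE_CORPUS
]
for t in threads:
    t.start()
for t in threads:
    t.join()

# stand-in for the rerank step: order merged candidates by score and cut to top_k
top_k = 2
reranked = sorted(all_documents, key=lambda d: d[1], reverse=True)[:top_k]
print(reranked)  # [('doc-a', 0.91), ('doc-c', 0.75)]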


@@ -0,0 +1,31 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel, ConfigDict
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
class DatasetRetrieverBaseTool(BaseModel, ABC):
"""Tool for querying a Dataset."""
name: str = "dataset"
description: str = "use this to retrieve a dataset. "
tenant_id: str
top_k: int = 4
score_threshold: float | None = None
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
return_resource: bool
retriever_from: str
model_config = ConfigDict(arbitrary_types_allowed=True)
def run(self, query: str) -> str:
"""Use the tool."""
return self._run(query)
@abstractmethod
def _run(self, query: str) -> str:
"""Use the tool.
Add run_manager: Optional[CallbackManagerForToolRun] = None
        to child implementations to enable tracing.
"""


@@ -0,0 +1,234 @@
from typing import Any, cast
from pydantic import BaseModel, Field
from sqlalchemy import select
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"reranking_mode": "reranking_model",
"top_k": 2,
"score_threshold_enabled": False,
}
class DatasetRetrieverToolInput(BaseModel):
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
class DatasetRetrieverTool(DatasetRetrieverBaseTool):
"""Tool for querying a Dataset."""
name: str = "dataset"
args_schema: type[BaseModel] = DatasetRetrieverToolInput
description: str = "use this to retrieve a dataset. "
dataset_id: str
user_id: str | None = None
retrieve_config: DatasetRetrieveConfigEntity
inputs: dict
@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
description = dataset.description
if not description:
description = "useful for when you want to answer queries about the " + dataset.name
description = description.replace("\n", "").replace("\r", "")
return cls(
name=f"dataset_{dataset.id.replace('-', '_')}",
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
description=description,
**kwargs,
)
def _run(self, query: str) -> str:
dataset_stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id)
dataset = db.session.scalar(dataset_stmt)
if not dataset:
return ""
for hit_callback in self.hit_callbacks:
hit_callback.on_query(query, dataset.id)
dataset_retrieval = DatasetRetrieval()
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
[dataset.id],
query,
self.tenant_id,
self.user_id or "unknown",
cast(str, self.retrieve_config.metadata_filtering_mode),
cast(ModelConfig, self.retrieve_config.metadata_model_config),
self.retrieve_config.metadata_filtering_conditions,
self.inputs,
)
if metadata_filter_document_ids:
document_ids_filter = metadata_filter_document_ids.get(dataset.id, [])
else:
document_ids_filter = None
if dataset.provider == "external":
results: list[RetrievalDocument] = []
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
query=query,
external_retrieval_parameters=dataset.retrieval_model,
metadata_condition=metadata_condition,
)
for external_document in external_documents:
document = RetrievalDocument(
page_content=external_document.get("content"),
metadata=external_document.get("metadata"),
provider="external",
)
if document.metadata is not None:
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset.id
document.metadata["dataset_name"] = dataset.name
results.append(document)
# deal with external documents
context_list: list[RetrievalSourceMetadata] = []
for position, item in enumerate(results, start=1):
if item.metadata is not None:
source = RetrievalSourceMetadata(
position=position,
dataset_id=item.metadata.get("dataset_id"),
dataset_name=item.metadata.get("dataset_name"),
document_id=item.metadata.get("document_id") or item.metadata.get("title"),
document_name=item.metadata.get("title"),
data_source_type="external",
retriever_from=self.retriever_from,
score=item.metadata.get("score"),
title=item.metadata.get("title"),
content=item.page_content,
)
context_list.append(source)
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join([item.page_content for item in results]))
else:
if metadata_condition and not document_ids_filter:
return ""
            # get the retrieval model; fall back to the default if none is configured
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
retrieval_resource_list: list[RetrievalSourceMetadata] = []
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
dataset_id=dataset.id,
query=query,
top_k=self.top_k,
document_ids_filter=document_ids_filter,
)
return str("\n".join([document.page_content for document in documents]))
else:
if self.top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
dataset_id=dataset.id,
query=query,
top_k=self.top_k,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model")
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights"),
document_ids_filter=document_ids_filter,
)
else:
documents = []
for hit_callback in self.hit_callbacks:
hit_callback.on_tool_end(documents)
document_score_list = {}
if dataset.indexing_technique != "economy":
for item in documents:
if item.metadata is not None and item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list: list[DocumentContext] = []
records = RetrievalService.format_retrieval_documents(documents)
if records:
for record in records:
segment = record.segment
if segment.answer:
document_context_list.append(
DocumentContext(
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
score=record.score,
)
)
else:
document_context_list.append(
DocumentContext(
content=segment.get_sign_content(),
score=record.score,
)
)
if self.return_resource:
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
document = db.session.scalar(dataset_document_stmt)
if dataset and document:
source = RetrievalSourceMetadata(
dataset_id=dataset.id,
dataset_name=dataset.name,
document_id=document.id,
document_name=document.name,
data_source_type=document.data_source_type,
segment_id=segment.id,
retriever_from=self.retriever_from,
score=record.score or 0.0,
doc_metadata=document.doc_metadata,
)
if self.retriever_from == "dev":
source.hit_count = segment.hit_count
source.word_count = segment.word_count
source.segment_position = segment.position
source.index_node_hash = segment.index_node_hash
if segment.answer:
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source.content = segment.content
retrieval_resource_list.append(source)
if self.return_resource and retrieval_resource_list:
retrieval_resource_list = sorted(
retrieval_resource_list,
key=lambda x: x.score or 0.0,
reverse=True,
)
for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore
item.position = position # type: ignore
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(retrieval_resource_list)
if document_context_list:
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
return str("\n".join([document_context.content for document_context in document_context_list]))
return ""


@@ -0,0 +1,136 @@
from collections.abc import Generator
from typing import Any
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ToolDescription,
ToolEntity,
ToolIdentity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
class DatasetRetrieverTool(Tool):
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool):
super().__init__(entity, runtime)
self.retrieval_tool = retrieval_tool
@staticmethod
def get_dataset_tools(
tenant_id: str,
dataset_ids: list[str],
retrieve_config: DatasetRetrieveConfigEntity | None,
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler,
user_id: str,
inputs: dict,
) -> list["DatasetRetrieverTool"]:
"""
        get dataset tools
"""
# check if retrieve_config is valid
if dataset_ids is None or len(dataset_ids) == 0:
return []
if retrieve_config is None:
return []
feature = DatasetRetrieval()
# save original retrieve strategy, and set retrieve strategy to SINGLE
        # Agents only support SINGLE mode
original_retriever_mode = retrieve_config.retrieve_strategy
retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
retrieval_tools = feature.to_dataset_retriever_tool(
tenant_id=tenant_id,
dataset_ids=dataset_ids,
retrieve_config=retrieve_config,
return_resource=return_resource,
invoke_from=invoke_from,
hit_callback=hit_callback,
user_id=user_id,
inputs=inputs,
)
if retrieval_tools is None or len(retrieval_tools) == 0:
return []
# restore retrieve strategy
retrieve_config.retrieve_strategy = original_retriever_mode
# convert retrieval tools to Tools
tools = []
for retrieval_tool in retrieval_tools:
tool = DatasetRetrieverTool(
retrieval_tool=retrieval_tool,
entity=ToolEntity(
identity=ToolIdentity(
provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="")
),
parameters=[],
description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description),
),
runtime=ToolRuntime(tenant_id=tenant_id),
)
tools.append(tool)
return tools
def get_runtime_parameters(
self,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> list[ToolParameter]:
return [
ToolParameter(
name="query",
label=I18nObject(en_US="", zh_Hans=""),
human_description=I18nObject(en_US="", zh_Hans=""),
type=ToolParameter.ToolParameterType.STRING,
form=ToolParameter.ToolParameterForm.LLM,
llm_description="Query for the dataset to be used to retrieve the dataset.",
required=True,
default="",
placeholder=I18nObject(en_US="", zh_Hans=""),
),
]
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.DATASET_RETRIEVAL
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
invoke dataset retriever tool
"""
query = tool_parameters.get("query")
if not query:
yield self.create_text_message(text="please input query")
else:
# invoke dataset retriever tool
result = self.retrieval_tool.run(query=query)
yield self.create_text_message(text=result)
def validate_credentials(
self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
) -> str | None:
"""
validate the credentials for dataset retriever tool
"""
pass
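
_invoke above is a generator, so a caller drains tool output message by message. A toy equivalent of that control flow, with plain strings standing in for ToolInvokeMessage:

from collections.abc import Generator

def invoke(tool_parameters: dict) -> Generator[str, None, None]:
    query = tool_parameters.get("query")
    if not query:
        yield "please input query"
    else:
        yield f"retrieved context for: {query}"

for message in invoke({"query": "refund policy"}):
    print(message)  # retrieved context for: refund policy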


@@ -0,0 +1,32 @@
# Import generic components from provider_encryption module
from core.helper.provider_encryption import (
ProviderConfigCache,
ProviderConfigEncrypter,
create_provider_encrypter,
)
# Re-export for backward compatibility
__all__ = [
"ProviderConfigCache",
"ProviderConfigEncrypter",
"create_provider_encrypter",
"create_tool_provider_encrypter",
]
# Tool-specific imports
from core.helper.provider_cache import SingletonProviderCredentialsCache
from core.tools.__base.tool_provider import ToolProviderController
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
cache = SingletonProviderCredentialsCache(
tenant_id=tenant_id,
provider_type=controller.provider_type.value,
provider_identity=controller.entity.identity.name,
)
encrypt = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
provider_config_cache=cache,
)
return encrypt, cache


@@ -0,0 +1,168 @@
import logging
from collections.abc import Generator
from datetime import date, datetime
from decimal import Decimal
from mimetypes import guess_extension
from uuid import UUID
import numpy as np
import pytz
from core.file import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_file_manager import ToolFileManager
from libs.login import current_user
from models import Account
logger = logging.getLogger(__name__)
def safe_json_value(v):
if isinstance(v, datetime):
tz_name = "UTC"
if isinstance(current_user, Account) and current_user.timezone is not None:
tz_name = current_user.timezone
return v.astimezone(pytz.timezone(tz_name)).isoformat()
elif isinstance(v, date):
return v.isoformat()
elif isinstance(v, UUID):
return str(v)
elif isinstance(v, Decimal):
return float(v)
elif isinstance(v, bytes):
try:
return v.decode("utf-8")
except UnicodeDecodeError:
return v.hex()
elif isinstance(v, memoryview):
return v.tobytes().hex()
elif isinstance(v, np.ndarray):
return v.tolist()
elif isinstance(v, dict):
return safe_json_dict(v)
elif isinstance(v, list | tuple | set):
return [safe_json_value(i) for i in v]
else:
return v
def safe_json_dict(d: dict):
if not isinstance(d, dict):
raise TypeError("safe_json_dict() expects a dictionary (dict) as input")
return {k: safe_json_value(v) for k, v in d.items()}
class ToolFileMessageTransformer:
@classmethod
def transform_tool_invoke_messages(
cls,
messages: Generator[ToolInvokeMessage, None, None],
user_id: str,
tenant_id: str,
conversation_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
Transform tool message and handle file download
"""
for message in messages:
if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
yield message
elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance(
message.message, ToolInvokeMessage.TextMessage
):
# try to download image
try:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
tool_file_manager = ToolFileManager()
tool_file = tool_file_manager.create_file_by_url(
user_id=user_id,
tenant_id=tenant_id,
file_url=message.message.text,
conversation_id=conversation_id,
)
url = f"/files/tools/{tool_file.id}{guess_extension(tool_file.mimetype) or '.png'}"
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=ToolInvokeMessage.TextMessage(text=url),
meta=message.meta.copy() if message.meta is not None else {},
)
except Exception as e:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT,
message=ToolInvokeMessage.TextMessage(
text=f"Failed to download image: {message.message.text}: {e}"
),
meta=message.meta.copy() if message.meta is not None else {},
)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get mime type and save blob to storage
meta = message.meta or {}
mimetype = meta.get("mime_type", "application/octet-stream")
# get filename from meta
filename = meta.get("filename", None)
                # only blob messages carry binary payloads here; anything else is unexpected
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
raise ValueError("unexpected message type")
assert isinstance(message.message.blob, bytes)
tool_file_manager = ToolFileManager()
tool_file = tool_file_manager.create_file_by_raw(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_binary=message.message.blob,
mimetype=mimetype,
filename=filename,
)
url = cls.get_tool_file_url(tool_file_id=tool_file.id, extension=guess_extension(tool_file.mimetype))
# check if file is image
if "image" in mimetype:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=ToolInvokeMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BINARY_LINK,
message=ToolInvokeMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
elif message.type == ToolInvokeMessage.MessageType.FILE:
meta = message.meta or {}
file = meta.get("file", None)
if isinstance(file, File):
if file.transfer_method == FileTransferMethod.TOOL_FILE:
assert file.related_id is not None
url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
if file.type == FileType.IMAGE:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=ToolInvokeMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK,
message=ToolInvokeMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield message
elif message.type == ToolInvokeMessage.MessageType.JSON:
if isinstance(message.message, ToolInvokeMessage.JsonMessage):
message.message.json_object = safe_json_value(message.message.json_object)
yield message
else:
yield message
@classmethod
def get_tool_file_url(cls, tool_file_id: str, extension: str | None) -> str:
return f"/files/tools/{tool_file_id}{extension or '.bin'}"


@@ -0,0 +1,167 @@
"""
In some cases, models are used inside tools such as WebScraperTool, WikipediaSearchTool, etc.
Therefore, a model manager is needed to list/invoke/validate models.
"""
import json
from decimal import Decimal
from typing import cast
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from models.tools import ToolModelInvoke
class InvokeModelError(Exception):
pass
class ModelInvocationUtils:
@staticmethod
def get_max_llm_context_tokens(
tenant_id: str,
) -> int:
"""
get max llm context tokens of the model
"""
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
)
if not model_instance:
raise InvokeModelError("Model not found")
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if not schema:
raise InvokeModelError("No model schema found")
max_tokens: int | None = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
if max_tokens is None:
return 2048
return max_tokens
@staticmethod
def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
"""
calculate tokens from prompt messages and model parameters
"""
# get model instance
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM)
if not model_instance:
raise InvokeModelError("Model not found")
# get tokens
tokens = model_instance.get_llm_num_tokens(prompt_messages)
return tokens
@staticmethod
def invoke(
user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
invoke model with parameters in user's own context
:param user_id: user id
:param tenant_id: tenant id, the tenant id of the creator of the tool
:param tool_type: tool type
:param tool_name: tool name
:param prompt_messages: prompt messages
        :return: LLMResult
"""
# get model manager
model_manager = ModelManager()
# get model instance
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
)
# get prompt tokens
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
model_parameters = {
"temperature": 0.8,
"top_p": 0.8,
}
# create tool model invoke
tool_model_invoke = ToolModelInvoke(
user_id=user_id,
tenant_id=tenant_id,
provider=model_instance.provider,
tool_type=tool_type,
tool_name=tool_name,
model_parameters=json.dumps(model_parameters),
prompt_messages=json.dumps(jsonable_encoder(prompt_messages)),
model_response="",
prompt_tokens=prompt_tokens,
answer_tokens=0,
answer_unit_price=Decimal(),
answer_price_unit=Decimal(),
provider_response_latency=0,
total_price=Decimal(),
currency="USD",
)
db.session.add(tool_model_invoke)
db.session.commit()
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
)
except InvokeRateLimitError as e:
raise InvokeModelError(f"Invoke rate limit error: {e}")
except InvokeBadRequestError as e:
raise InvokeModelError(f"Invoke bad request error: {e}")
except InvokeConnectionError as e:
raise InvokeModelError(f"Invoke connection error: {e}")
        except InvokeAuthorizationError:
            raise InvokeModelError("Invoke authorization error")
except InvokeServerUnavailableError as e:
raise InvokeModelError(f"Invoke server unavailable error: {e}")
except Exception as e:
raise InvokeModelError(f"Invoke error: {e}")
# update tool model invoke
tool_model_invoke.model_response = str(response.message.content)
if response.usage:
tool_model_invoke.answer_tokens = response.usage.completion_tokens
tool_model_invoke.answer_unit_price = response.usage.completion_unit_price
tool_model_invoke.answer_price_unit = response.usage.completion_price_unit
tool_model_invoke.provider_response_latency = response.usage.latency
tool_model_invoke.total_price = response.usage.total_price
tool_model_invoke.currency = response.usage.currency
db.session.commit()
return response
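
invoke() above funnels every provider-specific failure into InvokeModelError so callers only handle one exception type. A reduced sketch of that mapping pattern (the provider error class is a stand-in):

class InvokeModelError(Exception):
    pass

class RateLimitError(Exception):  # stand-in for InvokeRateLimitError
    pass

def call_model(fail: bool) -> str:
    if fail:
        raise RateLimitError("429 from provider")
    return "ok"

def invoke(fail: bool = False) -> str:
    try:
        return call_model(fail)
    except RateLimitError as e:
        raise InvokeModelError(f"Invoke rate limit error: {e}")
    except Exception as e:
        raise InvokeModelError(f"Invoke error: {e}")

try:
    invoke(fail=True)
except InvokeModelError as e:
    print(e)  # Invoke rate limit error: 429 from provider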


@@ -0,0 +1,453 @@
import re
from json import dumps as json_dumps
from json import loads as json_loads
from json.decoder import JSONDecodeError
from typing import Any
import httpx
from flask import request
from yaml import YAMLError, safe_load
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_to_tool_bundle(
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
# set description to extra_info
extra_info["description"] = openapi["info"].get("description", "")
if len(openapi["servers"]) == 0:
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
server_url = openapi["servers"][0]["url"]
request_env = request.headers.get("X-Request-Env")
if request_env:
            matched_servers = [server["url"] for server in openapi["servers"] if server.get("env") == request_env]
server_url = matched_servers[0] if matched_servers else server_url
# list all interfaces
interfaces = []
for path, path_item in openapi["paths"].items():
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
for method in methods:
if method in path_item:
interfaces.append(
{
"path": path,
"method": method,
"operation": path_item[method],
}
)
# get all parameters
bundles = []
for interface in interfaces:
# convert parameters
parameters = []
if "parameters" in interface["operation"]:
for i, parameter in enumerate(interface["operation"]["parameters"]):
if "$ref" in parameter:
root = openapi
reference = parameter["$ref"].split("/")[1:]
for ref in reference:
root = root[ref]
interface["operation"]["parameters"][i] = root
for parameter in interface["operation"]["parameters"]:
# Handle complex type defaults that are not supported by PluginParameter
default_value = None
if "schema" in parameter and "default" in parameter["schema"]:
default_value = ApiBasedToolSchemaParser._sanitize_default_value(parameter["schema"]["default"])
tool_parameter = ToolParameter(
name=parameter["name"],
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
human_description=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
type=ToolParameter.ToolParameterType.STRING,
required=parameter.get("required", False),
form=ToolParameter.ToolParameterForm.LLM,
llm_description=parameter.get("description"),
default=default_value,
placeholder=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
)
# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
if typ:
tool_parameter.type = typ
parameters.append(tool_parameter)
# create tool bundle
# check if there is a request body
if "requestBody" in interface["operation"]:
request_body = interface["operation"]["requestBody"]
if "content" in request_body:
for content_type, content in request_body["content"].items():
# if there is a reference, get the reference and overwrite the content
if "schema" not in content:
continue
if "$ref" in content["schema"]:
# get the reference
root = openapi
reference = content["schema"]["$ref"].split("/")[1:]
for ref in reference:
root = root[ref]
# overwrite the content
interface["operation"]["requestBody"]["content"][content_type]["schema"] = root
# handle allOf reference in schema properties
for prop_dict in root.get("properties", {}).values():
for item in prop_dict.get("allOf", []):
if "$ref" in item:
ref_schema = openapi
reference = item["$ref"].split("/")[1:]
for ref in reference:
ref_schema = ref_schema[ref]
else:
ref_schema = item
for key, value in ref_schema.items():
if isinstance(value, list):
if key not in prop_dict:
prop_dict[key] = []
# extends list field
if isinstance(prop_dict[key], list):
prop_dict[key].extend(value)
elif key not in prop_dict:
# add new field
prop_dict[key] = value
if "allOf" in prop_dict:
del prop_dict["allOf"]
# parse body parameters
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
required = body_schema.get("required", [])
properties = body_schema.get("properties", {})
for name, property in properties.items():
# Handle complex type defaults that are not supported by PluginParameter
default_value = ApiBasedToolSchemaParser._sanitize_default_value(
property.get("default", None)
)
tool = ToolParameter(
name=name,
label=I18nObject(en_US=name, zh_Hans=name),
human_description=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
type=ToolParameter.ToolParameterType.STRING,
required=name in required,
form=ToolParameter.ToolParameterForm.LLM,
llm_description=property.get("description", ""),
default=default_value,
placeholder=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
)
# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
if typ:
tool.type = typ
parameters.append(tool)
# check if parameters is duplicated
parameters_count = {}
for parameter in parameters:
if parameter.name not in parameters_count:
parameters_count[parameter.name] = 0
parameters_count[parameter.name] += 1
for name, count in parameters_count.items():
if count > 1:
warning["duplicated_parameter"] = f"Parameter {name} is duplicated."
            # check if there is an operation id; use $path_$method as the operation id if not
            if "operationId" not in interface["operation"]:
                path = interface["path"]
                if interface["path"].startswith("/"):
                    path = interface["path"][1:]
                # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
                path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
if not path:
path = "<root>"
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
bundles.append(
ApiToolBundle(
server_url=server_url + interface["path"],
method=interface["method"],
summary=interface["operation"]["description"]
if "description" in interface["operation"]
else interface["operation"].get("summary", None),
operation_id=interface["operation"]["operationId"],
parameters=parameters,
author="",
icon=None,
openapi=interface["operation"],
)
)
return bundles
@staticmethod
def _sanitize_default_value(value):
"""
Sanitize default values for PluginParameter compatibility.
Complex types (list, dict) are converted to None to avoid validation errors.
Args:
value: The default value from OpenAPI schema
Returns:
None for complex types (list, dict), otherwise the original value
"""
if isinstance(value, (list, dict)):
return None
return value
@staticmethod
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
parameter = parameter or {}
typ: str | None = None
if parameter.get("format") == "binary":
return ToolParameter.ToolParameterType.FILE
if "type" in parameter:
typ = parameter["type"]
elif "schema" in parameter and "type" in parameter["schema"]:
typ = parameter["schema"]["type"]
if typ in {"integer", "number"}:
return ToolParameter.ToolParameterType.NUMBER
elif typ == "boolean":
return ToolParameter.ToolParameterType.BOOLEAN
elif typ == "string":
return ToolParameter.ToolParameterType.STRING
elif typ == "array":
items = parameter.get("items") or parameter.get("schema", {}).get("items")
if items and items.get("format") == "binary":
return ToolParameter.ToolParameterType.FILES
else:
# For regular arrays, return ARRAY type instead of None
return ToolParameter.ToolParameterType.ARRAY
else:
return None
@staticmethod
def parse_openapi_yaml_to_tool_bundle(
yaml: str, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
"""
parse openapi yaml to tool bundle
:param yaml: the yaml string
:param extra_info: the extra info
:param warning: the warning message
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
openapi: dict = safe_load(yaml)
if openapi is None:
raise ToolApiSchemaError("Invalid openapi yaml.")
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
@staticmethod
def parse_swagger_to_openapi(
swagger: dict, extra_info: dict | None = None, warning: dict | None = None
) -> dict[str, Any]:
        """
        parse swagger to openapi
        :param swagger: the swagger dict
        :return: the openapi dict
        """
        warning = warning or {}
# convert swagger to openapi
info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})
servers = swagger.get("servers", [])
if len(servers) == 0:
raise ToolApiSchemaError("No server found in the swagger yaml.")
converted_openapi: dict[str, Any] = {
"openapi": "3.0.0",
"info": {
"title": info.get("title", "Swagger"),
"description": info.get("description", "Swagger"),
"version": info.get("version", "1.0.0"),
},
"servers": swagger["servers"],
"paths": {},
"components": {"schemas": {}},
}
# check paths
if "paths" not in swagger or len(swagger["paths"]) == 0:
raise ToolApiSchemaError("No paths found in the swagger yaml.")
# convert paths
for path, path_item in swagger["paths"].items():
converted_openapi["paths"][path] = {}
for method, operation in path_item.items():
if "operationId" not in operation:
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
if ("summary" not in operation or len(operation["summary"]) == 0) and (
"description" not in operation or len(operation["description"]) == 0
):
if warning is not None:
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
converted_openapi["paths"][path][method] = {
"operationId": operation["operationId"],
"summary": operation.get("summary", ""),
"description": operation.get("description", ""),
"parameters": operation.get("parameters", []),
"responses": operation.get("responses", {}),
}
if "requestBody" in operation:
converted_openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
# convert definitions
if "definitions" in swagger:
for name, definition in swagger["definitions"].items():
converted_openapi["components"]["schemas"][name] = definition
return converted_openapi
@staticmethod
def parse_openai_plugin_json_to_tool_bundle(
json: str, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
"""
        parse openai plugin json to tool bundle
:param json: the json string
:param extra_info: the extra info
:param warning: the warning message
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
try:
openai_plugin = json_loads(json)
api = openai_plugin["api"]
api_url = api["url"]
api_type = api["type"]
except JSONDecodeError:
raise ToolProviderNotFoundError("Invalid openai plugin json.")
if api_type != "openapi":
raise ToolNotSupportedError("Only openapi is supported now.")
# get openapi yaml
response = httpx.get(
api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5
)
try:
if response.status_code != 200:
raise ToolProviderNotFoundError("cannot get openapi yaml from url.")
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
response.text, extra_info=extra_info, warning=warning
)
finally:
response.close()
@staticmethod
def auto_parse_to_tool_bundle(
content: str, extra_info: dict | None = None, warning: dict | None = None
) -> tuple[list[ApiToolBundle], str]:
"""
auto parse to tool bundle
:param content: the content
:param extra_info: the extra info
:param warning: the warning message
:return: tools bundle, schema_type
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
content = content.strip()
loaded_content = None
json_error = None
yaml_error = None
try:
loaded_content = json_loads(content)
except JSONDecodeError as e:
json_error = e
if loaded_content is None:
try:
loaded_content = safe_load(content)
except YAMLError as e:
yaml_error = e
if loaded_content is None:
raise ToolApiSchemaError(
f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)},"
f" yaml error: {str(yaml_error)}"
)
swagger_error = None
openapi_error = None
openapi_plugin_error = None
schema_type = None
try:
openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
loaded_content, extra_info=extra_info, warning=warning
)
schema_type = ApiProviderSchemaType.OPENAPI
return openapi, schema_type
except ToolApiSchemaError as e:
openapi_error = e
        # openapi parse error, fall back to swagger
try:
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
loaded_content, extra_info=extra_info, warning=warning
)
schema_type = ApiProviderSchemaType.SWAGGER
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
converted_swagger, extra_info=extra_info, warning=warning
), schema_type
except ToolApiSchemaError as e:
swagger_error = e
        # swagger parse error, fall back to openai plugin
try:
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
json_dumps(loaded_content), extra_info=extra_info, warning=warning
)
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN
except ToolNotSupportedError as e:
# maybe it's not plugin at all
openapi_plugin_error = e
raise ToolApiSchemaError(
f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)},"
f" openapi plugin error: {str(openapi_plugin_error)}"
)
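
$ref resolution in the parser walks the spec from the document root along the fragment's path segments (dropping the leading "#"). A compact sketch of that traversal over a made-up spec:

def resolve_ref(openapi: dict, ref: str) -> dict:
    node = openapi
    for part in ref.split("/")[1:]:  # drop the leading "#"
        node = node[part]
    return node

spec = {"components": {"schemas": {"Pet": {"type": "object"}}}}
print(resolve_ref(spec, "#/components/schemas/Pet"))  # {'type': 'object'}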


@@ -0,0 +1,187 @@
import base64
import hashlib
import logging
from collections.abc import Mapping
from typing import Any
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad
from pydantic import TypeAdapter
from configs import dify_config
logger = logging.getLogger(__name__)
class OAuthEncryptionError(Exception):
"""OAuth encryption/decryption specific error"""
pass
class SystemOAuthEncrypter:
"""
A simple OAuth parameters encrypter using AES-CBC encryption.
This class provides methods to encrypt and decrypt OAuth parameters
using AES-CBC mode with a key derived from the application's SECRET_KEY.
"""
def __init__(self, secret_key: str | None = None):
"""
Initialize the OAuth encrypter.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
Raises:
ValueError: If SECRET_KEY is not configured or empty
"""
        secret_key = secret_key or dify_config.SECRET_KEY or ""
        if not secret_key:
            raise ValueError("SECRET_KEY is not configured or empty")
        # Generate a fixed 256-bit key using SHA-256
        self.key = hashlib.sha256(secret_key.encode()).digest()
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters.
Args:
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
Returns:
Base64-encoded encrypted string
Raises:
OAuthEncryptionError: If encryption fails
ValueError: If oauth_params is invalid
"""
try:
# Generate random IV (16 bytes)
iv = get_random_bytes(16)
# Create AES cipher (CBC mode)
cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Encrypt data
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
encrypted_data = cipher.encrypt(padded_data)
# Combine IV and encrypted data
combined = iv + encrypted_data
# Return base64 encoded string
return base64.b64encode(combined).decode()
except Exception as e:
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
Raises:
OAuthEncryptionError: If decryption fails
ValueError: If encrypted_data is invalid
"""
if not isinstance(encrypted_data, str):
raise ValueError("encrypted_data must be a string")
if not encrypted_data:
raise ValueError("encrypted_data cannot be empty")
try:
# Base64 decode
combined = base64.b64decode(encrypted_data)
# Check minimum length (IV + at least one AES block)
if len(combined) < 32: # 16 bytes IV + 16 bytes minimum encrypted data
raise ValueError("Invalid encrypted data format")
# Separate IV and encrypted data
iv = combined[:16]
encrypted_data_bytes = combined[16:]
# Create AES cipher
cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Decrypt data
decrypted_data = cipher.decrypt(encrypted_data_bytes)
unpadded_data = unpad(decrypted_data, AES.block_size)
# Parse JSON
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
if not isinstance(oauth_params, dict):
raise ValueError("Decrypted data is not a valid dictionary")
return oauth_params
except Exception as e:
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
# Factory function for creating encrypter instances
def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter:
"""
Create an OAuth encrypter instance.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
Returns:
SystemOAuthEncrypter instance
"""
return SystemOAuthEncrypter(secret_key=secret_key)
# Global encrypter instance (for backward compatibility)
_oauth_encrypter: SystemOAuthEncrypter | None = None
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
"""
Get the global OAuth encrypter instance.
Returns:
SystemOAuthEncrypter instance
"""
global _oauth_encrypter
if _oauth_encrypter is None:
_oauth_encrypter = SystemOAuthEncrypter()
return _oauth_encrypter
# Convenience functions for backward compatibility
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters using the global encrypter.
Args:
oauth_params: OAuth parameters dictionary
Returns:
Base64-encoded encrypted string
"""
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters using the global encrypter.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
"""
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
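
A round-trip sketch of the scheme above: a SHA-256-derived key, a random 16-byte IV prepended to the ciphertext, and PKCS#7 padding (requires pycryptodome; the demo key is made up):

import base64
import hashlib
import json

from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad

key = hashlib.sha256(b"demo-secret-key").digest()

def encrypt(params: dict) -> str:
    iv = get_random_bytes(16)
    cipher = AES.new(key, AES.MODE_CBC, iv)
    ciphertext = cipher.encrypt(pad(json.dumps(params).encode(), AES.block_size))
    return base64.b64encode(iv + ciphertext).decode()

def decrypt(token: str) -> dict:
    raw = base64.b64decode(token)
    cipher = AES.new(key, AES.MODE_CBC, raw[:16])
    return json.loads(unpad(cipher.decrypt(raw[16:]), AES.block_size))

print(decrypt(encrypt({"client_id": "abc"})))  # {'client_id': 'abc'}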


@@ -0,0 +1,17 @@
import re
def remove_leading_symbols(text: str) -> str:
"""
Remove leading punctuation or symbols from the given text.
Args:
text (str): The input text to process.
Returns:
str: The text with leading punctuation or symbols removed.
"""
# Match Unicode ranges for punctuation and symbols
    # FIXME: this pattern is a confusing quick fix for #11868; consider refactoring it later
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
return re.sub(pattern, "", text)
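
Example behavior, assuming the function above is importable (note that "-" is not in the stripped set):

print(remove_leading_symbols("...hello"))  # hello
print(remove_leading_symbols("!?#ready"))  # ready
print(remove_leading_symbols("clean"))     # clean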


@@ -0,0 +1,11 @@
import uuid
def is_valid_uuid(uuid_str: str | None) -> bool:
if uuid_str is None or len(uuid_str) == 0:
return False
try:
uuid.UUID(uuid_str)
return True
except Exception:
return False
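
A few sample calls, assuming the function above is importable:

print(is_valid_uuid("123e4567-e89b-12d3-a456-426614174000"))  # True
print(is_valid_uuid("not-a-uuid"))  # False
print(is_valid_uuid(None))  # False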


@@ -0,0 +1,128 @@
import mimetypes
import re
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, cast
from urllib.parse import unquote
import chardet
import cloudscraper
from readabilipy import simple_json_from_html_string
from core.helper import ssrf_proxy
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor
FULL_TEMPLATE = """
TITLE: {title}
AUTHOR: {author}
TEXT:
{text}
"""
def page_result(text: str, cursor: int, max_length: int) -> str:
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
return text[cursor : cursor + max_length]
def get_url(url: str, user_agent: str | None = None) -> str:
"""Fetch URL and return the contents as a string."""
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
" Chrome/91.0.4472.124 Safari/537.36"
}
if user_agent:
headers["User-Agent"] = user_agent
main_content_type = None
supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10))
if response.status_code == 200:
# check content-type
content_type = response.headers.get("Content-Type")
if content_type:
            main_content_type = content_type.split(";")[0].strip()
else:
content_disposition = response.headers.get("Content-Disposition", "")
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
if filename_match:
filename = unquote(filename_match.group(1))
extension = re.search(r"\.(\w+)$", filename)
if extension:
main_content_type = mimetypes.guess_type(filename)[0]
if main_content_type not in supported_content_types:
return f"Unsupported content-type [{main_content_type}] of URL."
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
return cast(str, ExtractProcessor.load_from_url(url, return_text=True))
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
elif response.status_code == 403:
scraper = cloudscraper.create_scraper()
scraper.perform_request = ssrf_proxy.make_request
response = scraper.get(url, headers=headers, timeout=(120, 300))
if response.status_code != 200:
return f"URL returned status code {response.status_code}."
# Detect encoding using chardet
detected_encoding = chardet.detect(response.content)
encoding = detected_encoding["encoding"]
if encoding:
try:
content = response.content.decode(encoding)
except (UnicodeDecodeError, TypeError):
content = response.text
else:
content = response.text
article = extract_using_readabilipy(content)
if not article.text:
return ""
res = FULL_TEMPLATE.format(
title=article.title,
author=article.author,
text=article.text,
)
return res
@dataclass
class Article:
title: str
author: str
text: Sequence[dict]
def extract_using_readabilipy(html: str):
json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True)
article = Article(
title=json_article.get("title") or "",
author=json_article.get("byline") or "",
text=json_article.get("plain_text") or [],
)
return article
def get_image_upload_file_ids(content):
pattern = r"!\[image\]\((https?://.*?(file-preview|image-preview))\)"
matches = re.findall(pattern, content)
image_upload_file_ids = []
for match in matches:
if match[1] == "file-preview":
content_pattern = r"files/([^/]+)/file-preview"
else:
content_pattern = r"files/([^/]+)/image-preview"
content_match = re.search(content_pattern, match[0])
if content_match:
image_upload_file_id = content_match.group(1)
image_upload_file_ids.append(image_upload_file_id)
return image_upload_file_ids
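
get_image_upload_file_ids pulls upload-file ids out of markdown image links, distinguishing file-preview from image-preview URLs. A quick check with a made-up message body, assuming the function above is importable:

content = "see ![image](https://example.com/files/abc-123/image-preview) for details"
print(get_image_upload_file_ids(content))  # ['abc-123']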


@@ -0,0 +1,43 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
for configuration in configurations:
WorkflowToolParameterConfiguration.model_validate(configuration)
@classmethod
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
"""
get workflow graph variables
"""
nodes = graph.get("nodes", [])
start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)
if not start_node:
return []
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
@classmethod
def check_is_synced(
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
):
"""
        check whether the workflow variables and tool parameter configurations are in sync
        raise ValueError if they are not
"""
variable_names = [variable.variable for variable in variables]
if len(tool_configurations) != len(variables):
raise ValueError("parameter configuration mismatch, please republish the tool to update")
for parameter in tool_configurations:
if parameter.name not in variable_names:
raise ValueError("parameter configuration mismatch, please republish the tool to update")


@@ -0,0 +1,33 @@
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any
import yaml
from yaml import YAMLError
logger = logging.getLogger(__name__)
def _load_yaml_file(*, file_path: str):
if not file_path or not Path(file_path).exists():
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, encoding="utf-8") as yaml_file:
try:
yaml_content = yaml.safe_load(yaml_file)
return yaml_content
except Exception as e:
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
@lru_cache(maxsize=128)
def load_yaml_file_cached(file_path: str) -> Any:
"""
Cached version of load_yaml_file for static configuration files.
Only use for files that don't change during runtime (e.g., position files)
:param file_path: the path of the YAML file
:return: an object of the YAML content
"""
return _load_yaml_file(file_path=file_path)
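
load_yaml_file_cached memoizes by file path via functools.lru_cache, so a given file is parsed at most once per process (hence the warning about files that change at runtime). A tiny demonstration of the same mechanism:

from functools import lru_cache

calls = 0

@lru_cache(maxsize=128)
def parse(path: str) -> str:
    global calls
    calls += 1
    return f"parsed:{path}"

parse("positions.yaml")
parse("positions.yaml")
print(calls)  # 1 (the second call is served from the cache)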