dify
0
dify/api/core/tools/utils/__init__.py
Normal file
158
dify/api/core/tools/utils/configuration.py
Normal file
@@ -0,0 +1,158 @@
import contextlib
from copy import deepcopy
from typing import Any

from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
    ToolParameter,
    ToolProviderType,
)


class ToolParameterConfigurationManager:
    """
    Tool parameter configuration manager
    """

    tenant_id: str
    tool_runtime: Tool
    provider_name: str
    provider_type: ToolProviderType
    identity_id: str

    def __init__(
        self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
    ):
        self.tenant_id = tenant_id
        self.tool_runtime = tool_runtime
        self.provider_name = provider_name
        self.provider_type = provider_type
        self.identity_id = identity_id

    def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
        """
        deep copy parameters
        """
        return deepcopy(parameters)

    def _merge_parameters(self) -> list[ToolParameter]:
        """
        merge parameters
        """
        # get tool parameters
        tool_parameters = self.tool_runtime.entity.parameters or []
        # get tool runtime parameters
        runtime_parameters = self.tool_runtime.get_runtime_parameters()
        # override parameters
        current_parameters = tool_parameters.copy()
        for runtime_parameter in runtime_parameters:
            found = False
            for index, parameter in enumerate(current_parameters):
                if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
                    current_parameters[index] = runtime_parameter
                    found = True
                    break

            if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
                current_parameters.append(runtime_parameter)

        return current_parameters

    def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
        """
        mask tool parameters

        return a deep copy of parameters with masked values
        """
        parameters = self._deep_copy(parameters)

        # override parameters
        current_parameters = self._merge_parameters()

        for parameter in current_parameters:
            if (
                parameter.form == ToolParameter.ToolParameterForm.FORM
                and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
            ):
                if parameter.name in parameters:
                    if len(parameters[parameter.name]) > 6:
                        parameters[parameter.name] = (
                            parameters[parameter.name][:2]
                            + "*" * (len(parameters[parameter.name]) - 4)
                            + parameters[parameter.name][-2:]
                        )
                    else:
                        parameters[parameter.name] = "*" * len(parameters[parameter.name])

        return parameters

    def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
        """
        encrypt tool parameters with tenant id

        return a deep copy of parameters with encrypted values
        """
        # override parameters
        current_parameters = self._merge_parameters()

        parameters = self._deep_copy(parameters)

        for parameter in current_parameters:
            if (
                parameter.form == ToolParameter.ToolParameterForm.FORM
                and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
            ):
                if parameter.name in parameters:
                    encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
                    parameters[parameter.name] = encrypted

        return parameters

    def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
        """
        decrypt tool parameters with tenant id

        return a deep copy of parameters with decrypted values
        """

        cache = ToolParameterCache(
            tenant_id=self.tenant_id,
            provider=f"{self.provider_type.value}.{self.provider_name}",
            tool_name=self.tool_runtime.entity.identity.name,
            cache_type=ToolParameterCacheType.PARAMETER,
            identity_id=self.identity_id,
        )
        cached_parameters = cache.get()
        if cached_parameters:
            return cached_parameters

        # override parameters
        current_parameters = self._merge_parameters()
        has_secret_input = False

        for parameter in current_parameters:
            if (
                parameter.form == ToolParameter.ToolParameterForm.FORM
                and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
            ):
                if parameter.name in parameters:
                    has_secret_input = True
                    with contextlib.suppress(Exception):
                        parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])

        if has_secret_input:
            cache.set(parameters)

        return parameters

    def delete_tool_parameters_cache(self):
        cache = ToolParameterCache(
            tenant_id=self.tenant_id,
            provider=f"{self.provider_type.value}.{self.provider_name}",
            tool_name=self.tool_runtime.entity.identity.name,
            cache_type=ToolParameterCacheType.PARAMETER,
            identity_id=self.identity_id,
        )
        cache.delete()
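The masking rule above keeps only the first and last two characters of a secret when it is longer than six characters and masks everything else; shorter secrets are masked entirely so nothing leaks. A minimal standalone sketch of that rule (the helper name mask_secret is ours, not part of the commit):

def mask_secret(value: str) -> str:
    # Keep the first/last two characters only when enough of the
    # secret stays hidden; otherwise mask every character.
    if len(value) > 6:
        return value[:2] + "*" * (len(value) - 4) + value[-2:]
    return "*" * len(value)


assert mask_secret("sk-abcdef123") == "sk********23"
assert mask_secret("abc") == "***"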
@@ -0,0 +1,200 @@
import threading
from typing import Any

from flask import Flask, current_app
from pydantic import BaseModel, Field
from sqlalchemy import select

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment

default_retrieval_model: dict[str, Any] = {
    "search_method": RetrievalMethod.SEMANTIC_SEARCH,
    "reranking_enable": False,
    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
    "top_k": 2,
    "score_threshold_enabled": False,
}


class DatasetMultiRetrieverToolInput(BaseModel):
    query: str = Field(..., description="dataset multi retriever and rerank")


class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
    """Tool for querying multiple datasets."""

    name: str = "dataset_"
    args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
    description: str = "dataset multi retriever and rerank. "
    dataset_ids: list[str]
    reranking_provider_name: str
    reranking_model_name: str

    @classmethod
    def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
        return cls(
            name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs
        )

    def _run(self, query: str) -> str:
        threads = []
        all_documents: list[RagDocument] = []
        for dataset_id in self.dataset_ids:
            retrieval_thread = threading.Thread(
                target=self._retriever,
                kwargs={
                    "flask_app": current_app._get_current_object(),  # type: ignore
                    "dataset_id": dataset_id,
                    "query": query,
                    "all_documents": all_documents,
                    "hit_callbacks": self.hit_callbacks,
                },
            )
            threads.append(retrieval_thread)
            retrieval_thread.start()
        for thread in threads:
            thread.join()
        # rerank the documents gathered from all datasets
        model_manager = ModelManager()
        rerank_model_instance = model_manager.get_model_instance(
            tenant_id=self.tenant_id,
            provider=self.reranking_provider_name,
            model_type=ModelType.RERANK,
            model=self.reranking_model_name,
        )

        rerank_runner = RerankModelRunner(rerank_model_instance)
        all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k)

        for hit_callback in self.hit_callbacks:
            hit_callback.on_tool_end(all_documents)

        document_score_list = {}
        for item in all_documents:
            if item.metadata and item.metadata.get("score"):
                document_score_list[item.metadata["doc_id"]] = item.metadata["score"]

        document_context_list = []
        index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
        document_segment_stmt = select(DocumentSegment).where(
            DocumentSegment.dataset_id.in_(self.dataset_ids),
            DocumentSegment.completed_at.isnot(None),
            DocumentSegment.status == "completed",
            DocumentSegment.enabled == True,
            DocumentSegment.index_node_id.in_(index_node_ids),
        )
        segments = db.session.scalars(document_segment_stmt).all()

        if segments:
            index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
            sorted_segments = sorted(
                segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
            )
            for segment in sorted_segments:
                if segment.answer:
                    document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
                else:
                    document_context_list.append(segment.get_sign_content())
            if self.return_resource:
                context_list: list[RetrievalSourceMetadata] = []
                resource_number = 1
                for segment in sorted_segments:
                    dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
                    document_stmt = select(Document).where(
                        Document.id == segment.document_id,
                        Document.enabled == True,
                        Document.archived == False,
                    )
                    document = db.session.scalar(document_stmt)
                    if dataset and document:
                        source = RetrievalSourceMetadata(
                            position=resource_number,
                            dataset_id=dataset.id,
                            dataset_name=dataset.name,
                            document_id=document.id,
                            document_name=document.name,
                            data_source_type=document.data_source_type,
                            segment_id=segment.id,
                            retriever_from=self.retriever_from,
                            score=document_score_list.get(segment.index_node_id),
                            doc_metadata=document.doc_metadata,
                        )

                        if self.retriever_from == "dev":
                            source.hit_count = segment.hit_count
                            source.word_count = segment.word_count
                            source.segment_position = segment.position
                            source.index_node_hash = segment.index_node_hash
                        if segment.answer:
                            source.content = f"question:{segment.content} \nanswer:{segment.answer}"
                        else:
                            source.content = segment.content
                        context_list.append(source)
                        resource_number += 1

                for hit_callback in self.hit_callbacks:
                    hit_callback.return_retriever_resource_info(context_list)

            return str("\n".join(document_context_list))
        return ""

    def _retriever(
        self,
        flask_app: Flask,
        dataset_id: str,
        query: str,
        all_documents: list,
        hit_callbacks: list[DatasetIndexToolCallbackHandler],
    ):
        with flask_app.app_context():
            stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id)
            dataset = db.session.scalar(stmt)

            if not dataset:
                return []

            for hit_callback in hit_callbacks:
                hit_callback.on_query(query, dataset.id)

            # get the retrieval model; use the default if none is configured
            retrieval_model = dataset.retrieval_model or default_retrieval_model

            if dataset.indexing_technique == "economy":
                # use keyword table query
                documents = RetrievalService.retrieve(
                    retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
                    dataset_id=dataset.id,
                    query=query,
                    top_k=retrieval_model.get("top_k") or 4,
                )
                if documents:
                    all_documents.extend(documents)
            else:
                if self.top_k > 0:
                    # retrieval source
                    documents = RetrievalService.retrieve(
                        retrieval_method=retrieval_model["search_method"],
                        dataset_id=dataset.id,
                        query=query,
                        top_k=retrieval_model.get("top_k") or 4,
                        score_threshold=retrieval_model.get("score_threshold", 0.0)
                        if retrieval_model["score_threshold_enabled"]
                        else 0.0,
                        reranking_model=retrieval_model.get("reranking_model", None)
                        if retrieval_model["reranking_enable"]
                        else None,
                        reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
                        weights=retrieval_model.get("weights", None),
                    )

                    all_documents.extend(documents)
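_run above fans the query out to each dataset on its own thread, joins them all, and only then reranks the combined results. A minimal sketch of that fan-out/join pattern (the fetch worker and dataset names are illustrative); appending to a shared list is safe here because list.append is atomic under the GIL:

import threading

def fetch(source: str, query: str, all_results: list[str]) -> None:
    # Stand-in for a per-dataset retrieval call.
    all_results.append(f"{source}:{query}")

all_results: list[str] = []
threads = [
    threading.Thread(target=fetch, args=(source, "hello", all_results))
    for source in ("dataset_a", "dataset_b")
]
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()
print(sorted(all_results))  # merging/reranking happens only after every join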
@@ -0,0 +1,31 @@
from abc import ABC, abstractmethod

from pydantic import BaseModel, ConfigDict

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler


class DatasetRetrieverBaseTool(BaseModel, ABC):
    """Tool for querying a Dataset."""

    name: str = "dataset"
    description: str = "use this to retrieve a dataset. "
    tenant_id: str
    top_k: int = 4
    score_threshold: float | None = None
    hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
    return_resource: bool
    retriever_from: str
    model_config = ConfigDict(arbitrary_types_allowed=True)

    def run(self, query: str) -> str:
        """Use the tool."""
        return self._run(query)

    @abstractmethod
    def _run(self, query: str) -> str:
        """Use the tool.

        Add run_manager: Optional[CallbackManagerForToolRun] = None
        to child implementations to enable tracing.
        """
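A concrete retriever only implements _run; the public run entry point stays on the base class. A toy subclass under that contract (the class name and canned result are ours, assuming the base class above is importable):

class EchoRetrieverTool(DatasetRetrieverBaseTool):
    """Toy retriever that returns the query itself as the 'document'."""

    def _run(self, query: str) -> str:
        return f"retrieved: {query}"


tool = EchoRetrieverTool(tenant_id="t-1", return_resource=False, retriever_from="dev")
print(tool.run("what is dify?"))  # retrieved: what is dify?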
@@ -0,0 +1,234 @@
from typing import Any, cast

from pydantic import BaseModel, Field
from sqlalchemy import select

from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService

default_retrieval_model: dict[str, Any] = {
    "search_method": RetrievalMethod.SEMANTIC_SEARCH,
    "reranking_enable": False,
    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
    "reranking_mode": "reranking_model",
    "top_k": 2,
    "score_threshold_enabled": False,
}


class DatasetRetrieverToolInput(BaseModel):
    query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")


class DatasetRetrieverTool(DatasetRetrieverBaseTool):
    """Tool for querying a Dataset."""

    name: str = "dataset"
    args_schema: type[BaseModel] = DatasetRetrieverToolInput
    description: str = "use this to retrieve a dataset. "
    dataset_id: str
    user_id: str | None = None
    retrieve_config: DatasetRetrieveConfigEntity
    inputs: dict

    @classmethod
    def from_dataset(cls, dataset: Dataset, **kwargs):
        description = dataset.description
        if not description:
            description = "useful for when you want to answer queries about the " + dataset.name

        description = description.replace("\n", "").replace("\r", "")
        return cls(
            name=f"dataset_{dataset.id.replace('-', '_')}",
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
            description=description,
            **kwargs,
        )

    def _run(self, query: str) -> str:
        dataset_stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id)
        dataset = db.session.scalar(dataset_stmt)

        if not dataset:
            return ""
        for hit_callback in self.hit_callbacks:
            hit_callback.on_query(query, dataset.id)
        dataset_retrieval = DatasetRetrieval()
        metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
            [dataset.id],
            query,
            self.tenant_id,
            self.user_id or "unknown",
            cast(str, self.retrieve_config.metadata_filtering_mode),
            cast(ModelConfig, self.retrieve_config.metadata_model_config),
            self.retrieve_config.metadata_filtering_conditions,
            self.inputs,
        )
        if metadata_filter_document_ids:
            document_ids_filter = metadata_filter_document_ids.get(dataset.id, [])
        else:
            document_ids_filter = None
        if dataset.provider == "external":
            results: list[RetrievalDocument] = []
            external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
                tenant_id=dataset.tenant_id,
                dataset_id=dataset.id,
                query=query,
                external_retrieval_parameters=dataset.retrieval_model,
                metadata_condition=metadata_condition,
            )
            for external_document in external_documents:
                document = RetrievalDocument(
                    page_content=external_document.get("content"),
                    metadata=external_document.get("metadata"),
                    provider="external",
                )
                if document.metadata is not None:
                    document.metadata["score"] = external_document.get("score")
                    document.metadata["title"] = external_document.get("title")
                    document.metadata["dataset_id"] = dataset.id
                    document.metadata["dataset_name"] = dataset.name
                    results.append(document)
            # deal with external documents
            context_list: list[RetrievalSourceMetadata] = []
            for position, item in enumerate(results, start=1):
                if item.metadata is not None:
                    source = RetrievalSourceMetadata(
                        position=position,
                        dataset_id=item.metadata.get("dataset_id"),
                        dataset_name=item.metadata.get("dataset_name"),
                        document_id=item.metadata.get("document_id") or item.metadata.get("title"),
                        document_name=item.metadata.get("title"),
                        data_source_type="external",
                        retriever_from=self.retriever_from,
                        score=item.metadata.get("score"),
                        title=item.metadata.get("title"),
                        content=item.page_content,
                    )
                    context_list.append(source)
            for hit_callback in self.hit_callbacks:
                hit_callback.return_retriever_resource_info(context_list)

            return str("\n".join([item.page_content for item in results]))
        else:
            if metadata_condition and not document_ids_filter:
                return ""
            # get the retrieval model; use the default if none is configured
            retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
            retrieval_resource_list: list[RetrievalSourceMetadata] = []
            if dataset.indexing_technique == "economy":
                # use keyword table query
                documents = RetrievalService.retrieve(
                    retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
                    dataset_id=dataset.id,
                    query=query,
                    top_k=self.top_k,
                    document_ids_filter=document_ids_filter,
                )
                return str("\n".join([document.page_content for document in documents]))
            else:
                if self.top_k > 0:
                    # retrieval source
                    documents = RetrievalService.retrieve(
                        retrieval_method=retrieval_model.get("search_method", "semantic_search"),
                        dataset_id=dataset.id,
                        query=query,
                        top_k=self.top_k,
                        score_threshold=retrieval_model.get("score_threshold", 0.0)
                        if retrieval_model["score_threshold_enabled"]
                        else 0.0,
                        reranking_model=retrieval_model.get("reranking_model")
                        if retrieval_model["reranking_enable"]
                        else None,
                        reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
                        weights=retrieval_model.get("weights"),
                        document_ids_filter=document_ids_filter,
                    )
                else:
                    documents = []
                for hit_callback in self.hit_callbacks:
                    hit_callback.on_tool_end(documents)
                document_score_list = {}
                if dataset.indexing_technique != "economy":
                    for item in documents:
                        if item.metadata is not None and item.metadata.get("score"):
                            document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
                document_context_list: list[DocumentContext] = []
                records = RetrievalService.format_retrieval_documents(documents)
                if records:
                    for record in records:
                        segment = record.segment
                        if segment.answer:
                            document_context_list.append(
                                DocumentContext(
                                    content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
                                    score=record.score,
                                )
                            )
                        else:
                            document_context_list.append(
                                DocumentContext(
                                    content=segment.get_sign_content(),
                                    score=record.score,
                                )
                            )

                    if self.return_resource:
                        for record in records:
                            segment = record.segment
                            dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
                            dataset_document_stmt = select(DatasetDocument).where(
                                DatasetDocument.id == segment.document_id,
                                DatasetDocument.enabled == True,
                                DatasetDocument.archived == False,
                            )
                            document = db.session.scalar(dataset_document_stmt)
                            if dataset and document:
                                source = RetrievalSourceMetadata(
                                    dataset_id=dataset.id,
                                    dataset_name=dataset.name,
                                    document_id=document.id,
                                    document_name=document.name,
                                    data_source_type=document.data_source_type,
                                    segment_id=segment.id,
                                    retriever_from=self.retriever_from,
                                    score=record.score or 0.0,
                                    doc_metadata=document.doc_metadata,
                                )

                                if self.retriever_from == "dev":
                                    source.hit_count = segment.hit_count
                                    source.word_count = segment.word_count
                                    source.segment_position = segment.position
                                    source.index_node_hash = segment.index_node_hash
                                if segment.answer:
                                    source.content = f"question:{segment.content} \nanswer:{segment.answer}"
                                else:
                                    source.content = segment.content
                                retrieval_resource_list.append(source)

                if self.return_resource and retrieval_resource_list:
                    retrieval_resource_list = sorted(
                        retrieval_resource_list,
                        key=lambda x: x.score or 0.0,
                        reverse=True,
                    )
                    for position, item in enumerate(retrieval_resource_list, start=1):  # type: ignore
                        item.position = position  # type: ignore
                    for hit_callback in self.hit_callbacks:
                        hit_callback.return_retriever_resource_info(retrieval_resource_list)
                if document_context_list:
                    document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
                    return str("\n".join([document_context.content for document_context in document_context_list]))
                return ""
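Both the resource list and the context list above are sorted with key=lambda x: x.score or 0.0, which treats a missing score (None) like zero so unscored hits sink to the bottom instead of raising a TypeError during comparison. A tiny illustration:

hits = [{"id": "a", "score": None}, {"id": "b", "score": 0.9}, {"id": "c", "score": 0.2}]
ranked = sorted(hits, key=lambda x: x["score"] or 0.0, reverse=True)
print([h["id"] for h in ranked])  # ['b', 'c', 'a']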
136
dify/api/core/tools/utils/dataset_retriever_tool.py
Normal file
@@ -0,0 +1,136 @@
from collections.abc import Generator
from typing import Any

from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
    ToolDescription,
    ToolEntity,
    ToolIdentity,
    ToolInvokeMessage,
    ToolParameter,
    ToolProviderType,
)
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool


class DatasetRetrieverTool(Tool):
    def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool):
        super().__init__(entity, runtime)
        self.retrieval_tool = retrieval_tool

    @staticmethod
    def get_dataset_tools(
        tenant_id: str,
        dataset_ids: list[str],
        retrieve_config: DatasetRetrieveConfigEntity | None,
        return_resource: bool,
        invoke_from: InvokeFrom,
        hit_callback: DatasetIndexToolCallbackHandler,
        user_id: str,
        inputs: dict,
    ) -> list["DatasetRetrieverTool"]:
        """
        get dataset tools
        """
        # check if retrieve_config is valid
        if dataset_ids is None or len(dataset_ids) == 0:
            return []
        if retrieve_config is None:
            return []

        feature = DatasetRetrieval()

        # save the original retrieve strategy and switch to SINGLE;
        # the agent only supports SINGLE mode
        original_retriever_mode = retrieve_config.retrieve_strategy
        retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
        retrieval_tools = feature.to_dataset_retriever_tool(
            tenant_id=tenant_id,
            dataset_ids=dataset_ids,
            retrieve_config=retrieve_config,
            return_resource=return_resource,
            invoke_from=invoke_from,
            hit_callback=hit_callback,
            user_id=user_id,
            inputs=inputs,
        )
        if retrieval_tools is None or len(retrieval_tools) == 0:
            return []

        # restore the retrieve strategy
        retrieve_config.retrieve_strategy = original_retriever_mode

        # convert retrieval tools to Tools
        tools = []
        for retrieval_tool in retrieval_tools:
            tool = DatasetRetrieverTool(
                retrieval_tool=retrieval_tool,
                entity=ToolEntity(
                    identity=ToolIdentity(
                        provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="")
                    ),
                    parameters=[],
                    description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description),
                ),
                runtime=ToolRuntime(tenant_id=tenant_id),
            )

            tools.append(tool)

        return tools

    def get_runtime_parameters(
        self,
        conversation_id: str | None = None,
        app_id: str | None = None,
        message_id: str | None = None,
    ) -> list[ToolParameter]:
        return [
            ToolParameter(
                name="query",
                label=I18nObject(en_US="", zh_Hans=""),
                human_description=I18nObject(en_US="", zh_Hans=""),
                type=ToolParameter.ToolParameterType.STRING,
                form=ToolParameter.ToolParameterForm.LLM,
                llm_description="Query for the dataset to be used to retrieve the dataset.",
                required=True,
                default="",
                placeholder=I18nObject(en_US="", zh_Hans=""),
            ),
        ]

    def tool_provider_type(self) -> ToolProviderType:
        return ToolProviderType.DATASET_RETRIEVAL

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
        conversation_id: str | None = None,
        app_id: str | None = None,
        message_id: str | None = None,
    ) -> Generator[ToolInvokeMessage, None, None]:
        """
        invoke dataset retriever tool
        """
        query = tool_parameters.get("query")
        if not query:
            yield self.create_text_message(text="please input query")
        else:
            # invoke dataset retriever tool
            result = self.retrieval_tool.run(query=query)
            yield self.create_text_message(text=result)

    def validate_credentials(
        self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
    ) -> str | None:
        """
        validate the credentials for dataset retriever tool
        """
        pass
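get_dataset_tools mutates retrieve_config to the SINGLE strategy and restores it afterwards, but an exception inside to_dataset_retriever_tool would leave the shared config mutated. A sketch of the same save/restore done with try/finally (the Config dataclass is a stand-in, not the dify entity):

from dataclasses import dataclass

@dataclass
class Config:
    strategy: str = "multiple"

def with_single_strategy(config: Config) -> list[str]:
    original = config.strategy
    config.strategy = "single"
    try:
        return ["tool built under single-strategy mode"]
    finally:
        # Restored even if tool construction raises.
        config.strategy = original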
32
dify/api/core/tools/utils/encryption.py
Normal file
@@ -0,0 +1,32 @@
# Import generic components from provider_encryption module
from core.helper.provider_encryption import (
    ProviderConfigCache,
    ProviderConfigEncrypter,
    create_provider_encrypter,
)

# Re-export for backward compatibility
__all__ = [
    "ProviderConfigCache",
    "ProviderConfigEncrypter",
    "create_provider_encrypter",
    "create_tool_provider_encrypter",
]

# Tool-specific imports
from core.helper.provider_cache import SingletonProviderCredentialsCache
from core.tools.__base.tool_provider import ToolProviderController


def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
    cache = SingletonProviderCredentialsCache(
        tenant_id=tenant_id,
        provider_type=controller.provider_type.value,
        provider_identity=controller.entity.identity.name,
    )
    encrypt = ProviderConfigEncrypter(
        tenant_id=tenant_id,
        config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
        provider_config_cache=cache,
    )
    return encrypt, cache
168
dify/api/core/tools/utils/message_transformer.py
Normal file
@@ -0,0 +1,168 @@
import logging
from collections.abc import Generator
from datetime import date, datetime
from decimal import Decimal
from mimetypes import guess_extension
from uuid import UUID

import numpy as np
import pytz

from core.file import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_file_manager import ToolFileManager
from libs.login import current_user
from models import Account

logger = logging.getLogger(__name__)


def safe_json_value(v):
    if isinstance(v, datetime):
        tz_name = "UTC"
        if isinstance(current_user, Account) and current_user.timezone is not None:
            tz_name = current_user.timezone
        return v.astimezone(pytz.timezone(tz_name)).isoformat()
    elif isinstance(v, date):
        return v.isoformat()
    elif isinstance(v, UUID):
        return str(v)
    elif isinstance(v, Decimal):
        return float(v)
    elif isinstance(v, bytes):
        try:
            return v.decode("utf-8")
        except UnicodeDecodeError:
            return v.hex()
    elif isinstance(v, memoryview):
        return v.tobytes().hex()
    elif isinstance(v, np.ndarray):
        return v.tolist()
    elif isinstance(v, dict):
        return safe_json_dict(v)
    elif isinstance(v, list | tuple | set):
        return [safe_json_value(i) for i in v]
    else:
        return v


def safe_json_dict(d: dict):
    if not isinstance(d, dict):
        raise TypeError("safe_json_dict() expects a dictionary (dict) as input")
    return {k: safe_json_value(v) for k, v in d.items()}


class ToolFileMessageTransformer:
    @classmethod
    def transform_tool_invoke_messages(
        cls,
        messages: Generator[ToolInvokeMessage, None, None],
        user_id: str,
        tenant_id: str,
        conversation_id: str | None = None,
    ) -> Generator[ToolInvokeMessage, None, None]:
        """
        Transform tool messages and handle file downloads
        """
        for message in messages:
            if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
                yield message
            elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance(
                message.message, ToolInvokeMessage.TextMessage
            ):
                # try to download the image
                try:
                    assert isinstance(message.message, ToolInvokeMessage.TextMessage)
                    tool_file_manager = ToolFileManager()
                    tool_file = tool_file_manager.create_file_by_url(
                        user_id=user_id,
                        tenant_id=tenant_id,
                        file_url=message.message.text,
                        conversation_id=conversation_id,
                    )

                    url = f"/files/tools/{tool_file.id}{guess_extension(tool_file.mimetype) or '.png'}"

                    yield ToolInvokeMessage(
                        type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                        message=ToolInvokeMessage.TextMessage(text=url),
                        meta=message.meta.copy() if message.meta is not None else {},
                    )
                except Exception as e:
                    yield ToolInvokeMessage(
                        type=ToolInvokeMessage.MessageType.TEXT,
                        message=ToolInvokeMessage.TextMessage(
                            text=f"Failed to download image: {message.message.text}: {e}"
                        ),
                        meta=message.meta.copy() if message.meta is not None else {},
                    )
            elif message.type == ToolInvokeMessage.MessageType.BLOB:
                # get the mime type and save the blob to storage
                meta = message.meta or {}

                mimetype = meta.get("mime_type", "application/octet-stream")
                # get the filename from meta
                filename = meta.get("filename", None)

                if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
                    raise ValueError("unexpected message type")

                assert isinstance(message.message.blob, bytes)
                tool_file_manager = ToolFileManager()
                tool_file = tool_file_manager.create_file_by_raw(
                    user_id=user_id,
                    tenant_id=tenant_id,
                    conversation_id=conversation_id,
                    file_binary=message.message.blob,
                    mimetype=mimetype,
                    filename=filename,
                )

                url = cls.get_tool_file_url(tool_file_id=tool_file.id, extension=guess_extension(tool_file.mimetype))

                # check if the file is an image
                if "image" in mimetype:
                    yield ToolInvokeMessage(
                        type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                        message=ToolInvokeMessage.TextMessage(text=url),
                        meta=meta.copy() if meta is not None else {},
                    )
                else:
                    yield ToolInvokeMessage(
                        type=ToolInvokeMessage.MessageType.BINARY_LINK,
                        message=ToolInvokeMessage.TextMessage(text=url),
                        meta=meta.copy() if meta is not None else {},
                    )
            elif message.type == ToolInvokeMessage.MessageType.FILE:
                meta = message.meta or {}
                file = meta.get("file", None)
                if isinstance(file, File):
                    if file.transfer_method == FileTransferMethod.TOOL_FILE:
                        assert file.related_id is not None
                        url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
                        if file.type == FileType.IMAGE:
                            yield ToolInvokeMessage(
                                type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                                message=ToolInvokeMessage.TextMessage(text=url),
                                meta=meta.copy() if meta is not None else {},
                            )
                        else:
                            yield ToolInvokeMessage(
                                type=ToolInvokeMessage.MessageType.LINK,
                                message=ToolInvokeMessage.TextMessage(text=url),
                                meta=meta.copy() if meta is not None else {},
                            )
                    else:
                        yield message

            elif message.type == ToolInvokeMessage.MessageType.JSON:
                if isinstance(message.message, ToolInvokeMessage.JsonMessage):
                    message.message.json_object = safe_json_value(message.message.json_object)
                yield message
            else:
                yield message

    @classmethod
    def get_tool_file_url(cls, tool_file_id: str, extension: str | None) -> str:
        return f"/files/tools/{tool_file_id}{extension or '.bin'}"
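safe_json_value and safe_json_dict exist because json.dumps rejects datetimes, UUIDs, Decimals, raw bytes, and NumPy arrays that tools routinely return. A quick check of the conversions (values chosen for illustration; datetime is omitted because that branch reads the current user's timezone and needs a request context):

import json
from datetime import date
from decimal import Decimal
from uuid import UUID

payload = {
    "when": date(2024, 1, 1),                            # -> "2024-01-01"
    "id": UUID("12345678-1234-5678-1234-567812345678"),  # -> str
    "price": Decimal("9.99"),                            # -> float
    "blob": b"\xff\xfe",                                 # undecodable -> hex string "fffe"
}
print(json.dumps(safe_json_dict(payload)))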
167
dify/api/core/tools/utils/model_invocation_utils.py
Normal file
@@ -0,0 +1,167 @@
"""
Models may be used inside tools such as WebScraperTool and WikipediaSearchTool,
so a model manager is needed to list/invoke/validate models.
"""

import json
from decimal import Decimal
from typing import cast

from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from models.tools import ToolModelInvoke


class InvokeModelError(Exception):
    pass


class ModelInvocationUtils:
    @staticmethod
    def get_max_llm_context_tokens(
        tenant_id: str,
    ) -> int:
        """
        get the max llm context tokens of the model
        """
        model_manager = ModelManager()
        model_instance = model_manager.get_default_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
        )

        if not model_instance:
            raise InvokeModelError("Model not found")

        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)

        if not schema:
            raise InvokeModelError("No model schema found")

        max_tokens: int | None = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
        if max_tokens is None:
            return 2048

        return max_tokens

    @staticmethod
    def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
        """
        calculate tokens from prompt messages and model parameters
        """

        # get model instance
        model_manager = ModelManager()
        model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM)

        if not model_instance:
            raise InvokeModelError("Model not found")

        # get tokens
        tokens = model_instance.get_llm_num_tokens(prompt_messages)

        return tokens

    @staticmethod
    def invoke(
        user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage]
    ) -> LLMResult:
        """
        invoke model with parameters in user's own context

        :param user_id: user id
        :param tenant_id: tenant id, the tenant id of the creator of the tool
        :param tool_type: tool type
        :param tool_name: tool name
        :param prompt_messages: prompt messages
        :return: AssistantPromptMessage
        """

        # get model manager
        model_manager = ModelManager()
        # get model instance
        model_instance = model_manager.get_default_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
        )

        # get prompt tokens
        prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)

        model_parameters = {
            "temperature": 0.8,
            "top_p": 0.8,
        }

        # create tool model invoke
        tool_model_invoke = ToolModelInvoke(
            user_id=user_id,
            tenant_id=tenant_id,
            provider=model_instance.provider,
            tool_type=tool_type,
            tool_name=tool_name,
            model_parameters=json.dumps(model_parameters),
            prompt_messages=json.dumps(jsonable_encoder(prompt_messages)),
            model_response="",
            prompt_tokens=prompt_tokens,
            answer_tokens=0,
            answer_unit_price=Decimal(),
            answer_price_unit=Decimal(),
            provider_response_latency=0,
            total_price=Decimal(),
            currency="USD",
        )

        db.session.add(tool_model_invoke)
        db.session.commit()

        try:
            response: LLMResult = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=model_parameters,
                tools=[],
                stop=[],
                stream=False,
                user=user_id,
                callbacks=[],
            )
        except InvokeRateLimitError as e:
            raise InvokeModelError(f"Invoke rate limit error: {e}")
        except InvokeBadRequestError as e:
            raise InvokeModelError(f"Invoke bad request error: {e}")
        except InvokeConnectionError as e:
            raise InvokeModelError(f"Invoke connection error: {e}")
        except InvokeAuthorizationError:
            raise InvokeModelError("Invoke authorization error")
        except InvokeServerUnavailableError as e:
            raise InvokeModelError(f"Invoke server unavailable error: {e}")
        except Exception as e:
            raise InvokeModelError(f"Invoke error: {e}")

        # update tool model invoke
        tool_model_invoke.model_response = str(response.message.content)
        if response.usage:
            tool_model_invoke.answer_tokens = response.usage.completion_tokens
            tool_model_invoke.answer_unit_price = response.usage.completion_unit_price
            tool_model_invoke.answer_price_unit = response.usage.completion_price_unit
            tool_model_invoke.provider_response_latency = response.usage.latency
            tool_model_invoke.total_price = response.usage.total_price
            tool_model_invoke.currency = response.usage.currency

        db.session.commit()

        return response
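A natural way to combine the first two helpers is a pre-flight budget check before calling invoke: compare the prompt's token count against the model's context size. A hedged sketch (fits_in_context is ours; it only composes the utilities defined above):

def fits_in_context(tenant_id: str, prompt_messages: list[PromptMessage], reserve: int = 512) -> bool:
    """Return True if the prompt plus a completion reserve fits the default LLM's context window."""
    max_tokens = ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id)
    prompt_tokens = ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages)
    return prompt_tokens + reserve <= max_tokens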
453
dify/api/core/tools/utils/parser.py
Normal file
@@ -0,0 +1,453 @@
|
||||
import re
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from flask import request
|
||||
from yaml import YAMLError, safe_load
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
|
||||
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
|
||||
|
||||
class ApiBasedToolSchemaParser:
|
||||
@staticmethod
|
||||
def parse_openapi_to_tool_bundle(
|
||||
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
# set description to extra_info
|
||||
extra_info["description"] = openapi["info"].get("description", "")
|
||||
|
||||
if len(openapi["servers"]) == 0:
|
||||
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
|
||||
|
||||
server_url = openapi["servers"][0]["url"]
|
||||
request_env = request.headers.get("X-Request-Env")
|
||||
if request_env:
|
||||
matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env]
|
||||
server_url = matched_servers[0] if matched_servers else server_url
|
||||
|
||||
# list all interfaces
|
||||
interfaces = []
|
||||
for path, path_item in openapi["paths"].items():
|
||||
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
|
||||
for method in methods:
|
||||
if method in path_item:
|
||||
interfaces.append(
|
||||
{
|
||||
"path": path,
|
||||
"method": method,
|
||||
"operation": path_item[method],
|
||||
}
|
||||
)
|
||||
|
||||
# get all parameters
|
||||
bundles = []
|
||||
for interface in interfaces:
|
||||
# convert parameters
|
||||
parameters = []
|
||||
if "parameters" in interface["operation"]:
|
||||
for i, parameter in enumerate(interface["operation"]["parameters"]):
|
||||
if "$ref" in parameter:
|
||||
root = openapi
|
||||
reference = parameter["$ref"].split("/")[1:]
|
||||
for ref in reference:
|
||||
root = root[ref]
|
||||
interface["operation"]["parameters"][i] = root
|
||||
for parameter in interface["operation"]["parameters"]:
|
||||
# Handle complex type defaults that are not supported by PluginParameter
|
||||
default_value = None
|
||||
if "schema" in parameter and "default" in parameter["schema"]:
|
||||
default_value = ApiBasedToolSchemaParser._sanitize_default_value(parameter["schema"]["default"])
|
||||
|
||||
tool_parameter = ToolParameter(
|
||||
name=parameter["name"],
|
||||
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
|
||||
human_description=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=parameter.get("required", False),
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=parameter.get("description"),
|
||||
default=default_value,
|
||||
placeholder=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
|
||||
if typ:
|
||||
tool_parameter.type = typ
|
||||
|
||||
parameters.append(tool_parameter)
|
||||
# create tool bundle
|
||||
# check if there is a request body
|
||||
if "requestBody" in interface["operation"]:
|
||||
request_body = interface["operation"]["requestBody"]
|
||||
if "content" in request_body:
|
||||
for content_type, content in request_body["content"].items():
|
||||
# if there is a reference, get the reference and overwrite the content
|
||||
if "schema" not in content:
|
||||
continue
|
||||
|
||||
if "$ref" in content["schema"]:
|
||||
# get the reference
|
||||
root = openapi
|
||||
reference = content["schema"]["$ref"].split("/")[1:]
|
||||
for ref in reference:
|
||||
root = root[ref]
|
||||
# overwrite the content
|
||||
interface["operation"]["requestBody"]["content"][content_type]["schema"] = root
|
||||
|
||||
# handle allOf reference in schema properties
|
||||
for prop_dict in root.get("properties", {}).values():
|
||||
for item in prop_dict.get("allOf", []):
|
||||
if "$ref" in item:
|
||||
ref_schema = openapi
|
||||
reference = item["$ref"].split("/")[1:]
|
||||
for ref in reference:
|
||||
ref_schema = ref_schema[ref]
|
||||
else:
|
||||
ref_schema = item
|
||||
for key, value in ref_schema.items():
|
||||
if isinstance(value, list):
|
||||
if key not in prop_dict:
|
||||
prop_dict[key] = []
|
||||
# extends list field
|
||||
if isinstance(prop_dict[key], list):
|
||||
prop_dict[key].extend(value)
|
||||
elif key not in prop_dict:
|
||||
# add new field
|
||||
prop_dict[key] = value
|
||||
if "allOf" in prop_dict:
|
||||
del prop_dict["allOf"]
|
||||
|
||||
# parse body parameters
|
||||
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
|
||||
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
|
||||
required = body_schema.get("required", [])
|
||||
properties = body_schema.get("properties", {})
|
||||
for name, property in properties.items():
|
||||
# Handle complex type defaults that are not supported by PluginParameter
|
||||
default_value = ApiBasedToolSchemaParser._sanitize_default_value(
|
||||
property.get("default", None)
|
||||
)
|
||||
|
||||
tool = ToolParameter(
|
||||
name=name,
|
||||
label=I18nObject(en_US=name, zh_Hans=name),
|
||||
human_description=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=name in required,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=property.get("description", ""),
|
||||
default=default_value,
|
||||
placeholder=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
)
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
|
||||
if typ:
|
||||
tool.type = typ
|
||||
|
||||
parameters.append(tool)
|
||||
|
||||
# check if parameters is duplicated
|
||||
parameters_count = {}
|
||||
for parameter in parameters:
|
||||
if parameter.name not in parameters_count:
|
||||
parameters_count[parameter.name] = 0
|
||||
parameters_count[parameter.name] += 1
|
||||
for name, count in parameters_count.items():
|
||||
if count > 1:
|
||||
warning["duplicated_parameter"] = f"Parameter {name} is duplicated."
|
||||
|
||||
# check if there is a operation id, use $path_$method as operation id if not
|
||||
if "operationId" not in interface["operation"]:
|
||||
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||
path = interface["path"]
|
||||
if interface["path"].startswith("/"):
|
||||
path = interface["path"][1:]
|
||||
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||
path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
|
||||
if not path:
|
||||
path = "<root>"
|
||||
|
||||
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
|
||||
|
||||
bundles.append(
|
||||
ApiToolBundle(
|
||||
server_url=server_url + interface["path"],
|
||||
method=interface["method"],
|
||||
summary=interface["operation"]["description"]
|
||||
if "description" in interface["operation"]
|
||||
else interface["operation"].get("summary", None),
|
||||
operation_id=interface["operation"]["operationId"],
|
||||
parameters=parameters,
|
||||
author="",
|
||||
icon=None,
|
||||
openapi=interface["operation"],
|
||||
)
|
||||
)
|
||||
|
||||
return bundles
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_default_value(value):
|
||||
"""
|
||||
Sanitize default values for PluginParameter compatibility.
|
||||
Complex types (list, dict) are converted to None to avoid validation errors.
|
||||
|
||||
Args:
|
||||
value: The default value from OpenAPI schema
|
||||
|
||||
Returns:
|
||||
None for complex types (list, dict), otherwise the original value
|
||||
"""
|
||||
if isinstance(value, (list, dict)):
|
||||
return None
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
|
||||
parameter = parameter or {}
|
||||
typ: str | None = None
|
||||
if parameter.get("format") == "binary":
|
||||
return ToolParameter.ToolParameterType.FILE
|
||||
|
||||
if "type" in parameter:
|
||||
typ = parameter["type"]
|
||||
elif "schema" in parameter and "type" in parameter["schema"]:
|
||||
typ = parameter["schema"]["type"]
|
||||
|
||||
if typ in {"integer", "number"}:
|
||||
return ToolParameter.ToolParameterType.NUMBER
|
||||
elif typ == "boolean":
|
||||
return ToolParameter.ToolParameterType.BOOLEAN
|
||||
elif typ == "string":
|
||||
return ToolParameter.ToolParameterType.STRING
|
||||
elif typ == "array":
|
||||
items = parameter.get("items") or parameter.get("schema", {}).get("items")
|
||||
if items and items.get("format") == "binary":
|
||||
return ToolParameter.ToolParameterType.FILES
|
||||
else:
|
||||
# For regular arrays, return ARRAY type instead of None
|
||||
return ToolParameter.ToolParameterType.ARRAY
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_openapi_yaml_to_tool_bundle(
|
||||
yaml: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi yaml to tool bundle
|
||||
|
||||
:param yaml: the yaml string
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: the tool bundle
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
openapi: dict = safe_load(yaml)
|
||||
if openapi is None:
|
||||
raise ToolApiSchemaError("Invalid openapi yaml.")
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
|
||||
|
||||
@staticmethod
|
||||
def parse_swagger_to_openapi(
|
||||
swagger: dict, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> dict[str, Any]:
|
||||
warning = warning or {}
|
||||
"""
|
||||
parse swagger to openapi
|
||||
|
||||
:param swagger: the swagger dict
|
||||
:return: the openapi dict
|
||||
"""
|
||||
# convert swagger to openapi
|
||||
info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})
|
||||
|
||||
servers = swagger.get("servers", [])
|
||||
|
||||
if len(servers) == 0:
|
||||
raise ToolApiSchemaError("No server found in the swagger yaml.")
|
||||
|
||||
converted_openapi: dict[str, Any] = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": info.get("title", "Swagger"),
|
||||
"description": info.get("description", "Swagger"),
|
||||
"version": info.get("version", "1.0.0"),
|
||||
},
|
||||
"servers": swagger["servers"],
|
||||
"paths": {},
|
||||
"components": {"schemas": {}},
|
||||
}
|
||||
|
||||
# check paths
|
||||
if "paths" not in swagger or len(swagger["paths"]) == 0:
|
||||
raise ToolApiSchemaError("No paths found in the swagger yaml.")
|
||||
|
||||
# convert paths
|
||||
for path, path_item in swagger["paths"].items():
|
||||
converted_openapi["paths"][path] = {}
|
||||
for method, operation in path_item.items():
|
||||
if "operationId" not in operation:
|
||||
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
|
||||
|
||||
if ("summary" not in operation or len(operation["summary"]) == 0) and (
|
||||
"description" not in operation or len(operation["description"]) == 0
|
||||
):
|
||||
if warning is not None:
|
||||
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
|
||||
|
||||
converted_openapi["paths"][path][method] = {
|
||||
"operationId": operation["operationId"],
|
||||
"summary": operation.get("summary", ""),
|
||||
"description": operation.get("description", ""),
|
||||
"parameters": operation.get("parameters", []),
|
||||
"responses": operation.get("responses", {}),
|
||||
}
|
||||
|
||||
if "requestBody" in operation:
|
||||
converted_openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
|
||||
|
||||
# convert definitions
|
||||
if "definitions" in swagger:
|
||||
for name, definition in swagger["definitions"].items():
|
||||
converted_openapi["components"]["schemas"][name] = definition
|
||||
|
||||
return converted_openapi
|
||||
|
||||
    @staticmethod
    def parse_openai_plugin_json_to_tool_bundle(
        json: str, extra_info: dict | None = None, warning: dict | None = None
    ) -> list[ApiToolBundle]:
        """
        parse openai plugin json to tool bundle

        :param json: the json string
        :param extra_info: the extra info
        :param warning: the warning message
        :return: the tool bundle
        """
        warning = warning if warning is not None else {}
        extra_info = extra_info if extra_info is not None else {}

        try:
            openai_plugin = json_loads(json)
            api = openai_plugin["api"]
            api_url = api["url"]
            api_type = api["type"]
        except (JSONDecodeError, KeyError):
            # the payload is either not json or is missing the expected "api" fields
            raise ToolProviderNotFoundError("Invalid openai plugin json.")

        if api_type != "openapi":
            raise ToolNotSupportedError("Only openapi is supported now.")

        # get openapi yaml
        response = httpx.get(
            api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5
        )

        try:
            if response.status_code != 200:
                raise ToolProviderNotFoundError("cannot get openapi yaml from url.")

            return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
                response.text, extra_info=extra_info, warning=warning
            )
        finally:
            response.close()

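    # A sketch of the manifest shape this parser expects (assumed example values,
    # not from the original file):
    #
    #     {"api": {"type": "openapi", "url": "https://example.com/openapi.yaml"}}
    #
    # Any other "type" raises ToolNotSupportedError; non-json payloads or missing
    # "api" fields raise ToolProviderNotFoundError.
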
    @staticmethod
    def auto_parse_to_tool_bundle(
        content: str, extra_info: dict | None = None, warning: dict | None = None
    ) -> tuple[list[ApiToolBundle], str]:
        """
        auto parse to tool bundle

        :param content: the content
        :param extra_info: the extra info
        :param warning: the warning message
        :return: tools bundle, schema_type
        """
        warning = warning if warning is not None else {}
        extra_info = extra_info if extra_info is not None else {}

        content = content.strip()
        loaded_content = None
        json_error = None
        yaml_error = None

        try:
            loaded_content = json_loads(content)
        except JSONDecodeError as e:
            json_error = e

        if loaded_content is None:
            try:
                loaded_content = safe_load(content)
            except YAMLError as e:
                yaml_error = e
            if loaded_content is None:
                raise ToolApiSchemaError(
                    f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)},"
                    f" yaml error: {str(yaml_error)}"
                )

        swagger_error = None
        openapi_error = None
        openapi_plugin_error = None
        schema_type = None

        try:
            openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
                loaded_content, extra_info=extra_info, warning=warning
            )
            schema_type = ApiProviderSchemaType.OPENAPI
            return openapi, schema_type
        except ToolApiSchemaError as e:
            openapi_error = e

        # openapi parse error, fall back to swagger
        try:
            converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
                loaded_content, extra_info=extra_info, warning=warning
            )
            schema_type = ApiProviderSchemaType.SWAGGER
            return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
                converted_swagger, extra_info=extra_info, warning=warning
            ), schema_type
        except ToolApiSchemaError as e:
            swagger_error = e

        # swagger parse error, fall back to openai plugin
        try:
            openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
                json_dumps(loaded_content), extra_info=extra_info, warning=warning
            )
            return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN
        except ToolNotSupportedError as e:
            # maybe it's not a plugin at all
            openapi_plugin_error = e

        raise ToolApiSchemaError(
            f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)},"
            f" openapi plugin error: {str(openapi_plugin_error)}"
        )
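
A minimal usage sketch for the parser above (hypothetical schema; the import path is assumed from this commit's layout, and ApiToolBundle.operation_id is assumed from its definition elsewhere in the codebase). JSON parses on the first attempt, so the OPENAPI branch is taken:

    import json

    from core.tools.utils.parser import ApiBasedToolSchemaParser

    schema = json.dumps({
        "openapi": "3.0.0",
        "info": {"title": "Pets", "version": "1.0.0"},
        "servers": [{"url": "https://pets.example.com"}],
        "paths": {"/pets": {"get": {"operationId": "listPets", "summary": "List pets"}}},
    })
    bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
    print(schema_type, [bundle.operation_id for bundle in bundles])
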
187
dify/api/core/tools/utils/system_oauth_encryption.py
Normal file
@@ -0,0 +1,187 @@
import base64
import hashlib
import logging
from collections.abc import Mapping
from typing import Any

from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad
from pydantic import TypeAdapter

from configs import dify_config

logger = logging.getLogger(__name__)


class OAuthEncryptionError(Exception):
    """OAuth encryption/decryption specific error"""

    pass


class SystemOAuthEncrypter:
    """
    A simple OAuth parameters encrypter using AES-CBC encryption.

    This class provides methods to encrypt and decrypt OAuth parameters
    using AES-CBC mode with a key derived from the application's SECRET_KEY.
    """

    def __init__(self, secret_key: str | None = None):
        """
        Initialize the OAuth encrypter.

        Args:
            secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY.
                If neither is set, the key is derived from an empty string.
        """
        secret_key = secret_key or dify_config.SECRET_KEY or ""

        # Derive a fixed 256-bit key from the secret using SHA-256
        self.key = hashlib.sha256(secret_key.encode()).digest()

    def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
        """
        Encrypt OAuth parameters.

        Args:
            oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}

        Returns:
            Base64-encoded encrypted string

        Raises:
            OAuthEncryptionError: If encryption fails for any reason
        """
        try:
            # Generate random IV (16 bytes)
            iv = get_random_bytes(16)

            # Create AES cipher (CBC mode)
            cipher = AES.new(self.key, AES.MODE_CBC, iv)

            # Serialize to JSON, pad to the AES block size, and encrypt
            padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
            encrypted_data = cipher.encrypt(padded_data)

            # Combine IV and encrypted data
            combined = iv + encrypted_data

            # Return base64 encoded string
            return base64.b64encode(combined).decode()

        except Exception as e:
            raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e

    def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
        """
        Decrypt OAuth parameters.

        Args:
            encrypted_data: Base64-encoded encrypted string

        Returns:
            Decrypted OAuth parameters dictionary

        Raises:
            ValueError: If encrypted_data is not a non-empty string
            OAuthEncryptionError: If decryption fails
        """
        if not isinstance(encrypted_data, str):
            raise ValueError("encrypted_data must be a string")

        if not encrypted_data:
            raise ValueError("encrypted_data cannot be empty")

        try:
            # Base64 decode
            combined = base64.b64decode(encrypted_data)

            # Check minimum length (IV + at least one AES block)
            if len(combined) < 32:  # 16 bytes IV + 16 bytes minimum encrypted data
                raise ValueError("Invalid encrypted data format")

            # Separate IV and encrypted data
            iv = combined[:16]
            encrypted_data_bytes = combined[16:]

            # Create AES cipher
            cipher = AES.new(self.key, AES.MODE_CBC, iv)

            # Decrypt data
            decrypted_data = cipher.decrypt(encrypted_data_bytes)
            unpadded_data = unpad(decrypted_data, AES.block_size)

            # Parse JSON
            oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)

            if not isinstance(oauth_params, dict):
                raise ValueError("Decrypted data is not a valid dictionary")

            return oauth_params

        except Exception as e:
            raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e


# Factory function for creating encrypter instances
def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter:
    """
    Create an OAuth encrypter instance.

    Args:
        secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY

    Returns:
        SystemOAuthEncrypter instance
    """
    return SystemOAuthEncrypter(secret_key=secret_key)


# Global encrypter instance (for backward compatibility)
_oauth_encrypter: SystemOAuthEncrypter | None = None


def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
    """
    Get the global OAuth encrypter instance.

    Returns:
        SystemOAuthEncrypter instance
    """
    global _oauth_encrypter
    if _oauth_encrypter is None:
        _oauth_encrypter = SystemOAuthEncrypter()
    return _oauth_encrypter


# Convenience functions for backward compatibility
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
    """
    Encrypt OAuth parameters using the global encrypter.

    Args:
        oauth_params: OAuth parameters dictionary

    Returns:
        Base64-encoded encrypted string
    """
    return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)


def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
    """
    Decrypt OAuth parameters using the global encrypter.

    Args:
        encrypted_data: Base64-encoded encrypted string

    Returns:
        Decrypted OAuth parameters dictionary
    """
    return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
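
A round-trip sketch for the module above (sample values; assumes dify_config.SECRET_KEY is configured):

    from core.tools.utils.system_oauth_encryption import (
        decrypt_system_oauth_params,
        encrypt_system_oauth_params,
    )

    token = encrypt_system_oauth_params({"client_id": "abc", "client_secret": "xyz"})
    assert decrypt_system_oauth_params(token) == {"client_id": "abc", "client_secret": "xyz"}
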
17
dify/api/core/tools/utils/text_processing_utils.py
Normal file
@@ -0,0 +1,17 @@
import re


def remove_leading_symbols(text: str) -> str:
    """
    Remove leading punctuation or symbols from the given text.

    Args:
        text (str): The input text to process.

    Returns:
        str: The text with leading punctuation or symbols removed.
    """
    # Match Unicode ranges for punctuation and symbols
    # FIXME: this pattern is a confusing quick fix for #11868; consider refactoring it later
    pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
    return re.sub(pattern, "", text)
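
A few illustrative calls (assumed inputs): only leading characters that fall inside the pattern's punctuation and symbol ranges are stripped.

    from core.tools.utils.text_processing_utils import remove_leading_symbols

    assert remove_leading_symbols("...hello") == "hello"
    assert remove_leading_symbols("!?#ready") == "ready"
    assert remove_leading_symbols("plain text") == "plain text"
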
11
dify/api/core/tools/utils/uuid_utils.py
Normal file
@@ -0,0 +1,11 @@
import uuid


def is_valid_uuid(uuid_str: str | None) -> bool:
    if uuid_str is None or len(uuid_str) == 0:
        return False
    try:
        uuid.UUID(uuid_str)
        return True
    except Exception:
        return False
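
Example behaviour (sample values): strings uuid.UUID accepts pass; everything else, including None and the empty string, returns False.

    from core.tools.utils.uuid_utils import is_valid_uuid

    assert is_valid_uuid("123e4567-e89b-12d3-a456-426614174000")
    assert not is_valid_uuid("not-a-uuid")
    assert not is_valid_uuid(None)
    assert not is_valid_uuid("")
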
128
dify/api/core/tools/utils/web_reader_tool.py
Normal file
@@ -0,0 +1,128 @@
import mimetypes
import re
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, cast
from urllib.parse import unquote

import chardet
import cloudscraper
from readabilipy import simple_json_from_html_string

from core.helper import ssrf_proxy
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor

FULL_TEMPLATE = """
TITLE: {title}
AUTHOR: {author}
TEXT:

{text}
"""


def page_result(text: str, cursor: int, max_length: int) -> str:
    """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
    return text[cursor : cursor + max_length]


def get_url(url: str, user_agent: str | None = None) -> str:
    """Fetch URL and return the contents as a string."""
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
        " Chrome/91.0.4472.124 Safari/537.36"
    }
    if user_agent:
        headers["User-Agent"] = user_agent

    main_content_type = None
    supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
    response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10))

    if response.status_code == 200:
        # check content-type
        content_type = response.headers.get("Content-Type")
        if content_type:
            main_content_type = content_type.split(";")[0].strip()
        else:
            # no Content-Type header; fall back to guessing from the filename
            # in the Content-Disposition header
            content_disposition = response.headers.get("Content-Disposition", "")
            filename_match = re.search(r'filename="([^"]+)"', content_disposition)
            if filename_match:
                filename = unquote(filename_match.group(1))
                extension = re.search(r"\.(\w+)$", filename)
                if extension:
                    main_content_type = mimetypes.guess_type(filename)[0]

        if main_content_type not in supported_content_types:
            return f"Unsupported content-type [{main_content_type}] of URL."

        if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
            return cast(str, ExtractProcessor.load_from_url(url, return_text=True))

        response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
    elif response.status_code == 403:
        scraper = cloudscraper.create_scraper()
        scraper.perform_request = ssrf_proxy.make_request
        response = scraper.get(url, headers=headers, timeout=(120, 300))

    if response.status_code != 200:
        return f"URL returned status code {response.status_code}."

    # Detect encoding using chardet
    detected_encoding = chardet.detect(response.content)
    encoding = detected_encoding["encoding"]
    if encoding:
        try:
            content = response.content.decode(encoding)
        except (UnicodeDecodeError, TypeError):
            content = response.text
    else:
        content = response.text

    article = extract_using_readabilipy(content)

    if not article.text:
        return ""

    res = FULL_TEMPLATE.format(
        title=article.title,
        author=article.author,
        text=article.text,
    )

    return res


@dataclass
class Article:
    title: str
    author: str
    text: Sequence[dict]


def extract_using_readabilipy(html: str):
    json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True)
    article = Article(
        title=json_article.get("title") or "",
        author=json_article.get("byline") or "",
        text=json_article.get("plain_text") or [],
    )

    return article


def get_image_upload_file_ids(content):
    pattern = r"!\[image\]\((https?://.*?(file-preview|image-preview))\)"
    matches = re.findall(pattern, content)
    image_upload_file_ids = []
    for match in matches:
        if match[1] == "file-preview":
            content_pattern = r"files/([^/]+)/file-preview"
        else:
            content_pattern = r"files/([^/]+)/image-preview"
        content_match = re.search(content_pattern, match[0])
        if content_match:
            image_upload_file_id = content_match.group(1)
            image_upload_file_ids.append(image_upload_file_id)
    return image_upload_file_ids
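
A small sketch of the markdown scan above (hypothetical URL): the file id is the path segment before file-preview or image-preview.

    from core.tools.utils.web_reader_tool import get_image_upload_file_ids

    md = "![image](https://cloud.example.com/files/abc123/image-preview)"
    assert get_image_upload_file_ids(md) == ["abc123"]
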
43
dify/api/core/tools/utils/workflow_configuration_sync.py
Normal file
@@ -0,0 +1,43 @@
from collections.abc import Mapping, Sequence
from typing import Any

from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration


class WorkflowToolConfigurationUtils:
    @classmethod
    def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
        for configuration in configurations:
            WorkflowToolParameterConfiguration.model_validate(configuration)

    @classmethod
    def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
        """
        get the variables declared on the workflow graph's start node
        """
        nodes = graph.get("nodes", [])
        start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)

        if not start_node:
            return []

        return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]

    @classmethod
    def check_is_synced(
        cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
    ):
        """
        check whether the tool configurations are synced with the workflow variables

        raise ValueError if not synced
        """
        variable_names = [variable.variable for variable in variables]

        if len(tool_configurations) != len(variables):
            raise ValueError("parameter configuration mismatch, please republish the tool to update")

        for parameter in tool_configurations:
            if parameter.name not in variable_names:
                raise ValueError("parameter configuration mismatch, please republish the tool to update")
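
A usage sketch (hypothetical graph; the variable fields are assumed from VariableEntity): variables are read from the start node's data.

    from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils

    graph = {
        "nodes": [
            {
                "data": {
                    "type": "start",
                    "variables": [{"variable": "query", "label": "Query", "type": "text-input"}],
                }
            }
        ]
    }
    variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
    assert [v.variable for v in variables] == ["query"]
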
33
dify/api/core/tools/utils/yaml_utils.py
Normal file
@@ -0,0 +1,33 @@
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any

import yaml
from yaml import YAMLError

logger = logging.getLogger(__name__)


def _load_yaml_file(*, file_path: str):
    if not file_path or not Path(file_path).exists():
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, encoding="utf-8") as yaml_file:
        try:
            yaml_content = yaml.safe_load(yaml_file)
            return yaml_content
        except Exception as e:
            raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e


@lru_cache(maxsize=128)
def load_yaml_file_cached(file_path: str) -> Any:
    """
    Cached version of load_yaml_file for static configuration files.
    Only use for files that don't change during runtime (e.g., position files).

    :param file_path: the path of the YAML file
    :return: an object of the YAML content
    """
    return _load_yaml_file(file_path=file_path)
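
A usage sketch (hypothetical file path): repeated calls with the same path are served from the lru_cache, so the file is read once and the same parsed object is returned.

    from core.tools.utils.yaml_utils import load_yaml_file_cached

    first = load_yaml_file_cached("conf/_position.yaml")
    second = load_yaml_file_cached("conf/_position.yaml")
    assert first is second
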