2025-12-01 17:21:38 +08:00
parent 32fee2b8ab
commit fab8c13cb3
7511 changed files with 996300 additions and 0 deletions

View File

@@ -0,0 +1,233 @@
from abc import ABC, abstractmethod
from collections.abc import Generator
from copy import deepcopy
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from models.model import File
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import (
ToolEntity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
class Tool(ABC):
"""
The base class of a tool
"""
def __init__(self, entity: ToolEntity, runtime: ToolRuntime):
self.entity = entity
self.runtime = runtime
def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool":
"""
fork a new tool with metadata
:return: the new tool
"""
return self.__class__(
entity=self.entity.model_copy(),
runtime=runtime,
)
@abstractmethod
def tool_provider_type(self) -> ToolProviderType:
"""
get the tool provider type
:return: the tool provider type
"""
def invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage]:
if self.runtime and self.runtime.runtime_parameters:
tool_parameters.update(self.runtime.runtime_parameters)
        # try to parse tool parameters into their declared types
tool_parameters = self._transform_tool_parameters_type(tool_parameters)
result = self._invoke(
user_id=user_id,
tool_parameters=tool_parameters,
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
)
if isinstance(result, ToolInvokeMessage):
def single_generator() -> Generator[ToolInvokeMessage, None, None]:
yield result
return single_generator()
elif isinstance(result, list):
def generator() -> Generator[ToolInvokeMessage, None, None]:
yield from result
return generator()
else:
return result
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
"""
Transform tool parameters type
"""
        # Temporary fix: tool parameters were being converted to empty values while validating credentials
result = deepcopy(tool_parameters)
for parameter in self.entity.parameters or []:
if parameter.name in tool_parameters:
result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
return result
@abstractmethod
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]:
pass
def get_runtime_parameters(
self,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> list[ToolParameter]:
"""
get the runtime parameters
        interface for developers to dynamically change a tool's parameters depending on the variable pool
:return: the runtime parameters
"""
return self.entity.parameters
def get_merged_runtime_parameters(
self,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> list[ToolParameter]:
"""
get merged runtime parameters
:return: merged runtime parameters
"""
parameters = self.entity.parameters
parameters = parameters.copy()
user_parameters = self.get_runtime_parameters() or []
user_parameters = user_parameters.copy()
# override parameters
for parameter in user_parameters:
# check if parameter in tool parameters
for tool_parameter in parameters:
if tool_parameter.name == parameter.name:
# override parameter
tool_parameter.type = parameter.type
tool_parameter.form = parameter.form
tool_parameter.required = parameter.required
tool_parameter.default = parameter.default
tool_parameter.options = parameter.options
tool_parameter.llm_description = parameter.llm_description
break
else:
# add new parameter
parameters.append(parameter)
return parameters
def create_image_message(
self,
image: str,
) -> ToolInvokeMessage:
"""
create an image message
:param image: the url of the image
:return: the image message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
)
def create_file_message(self, file: "File") -> ToolInvokeMessage:
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.FILE,
message=ToolInvokeMessage.FileMessage(),
meta={"file": file},
)
def create_link_message(self, link: str) -> ToolInvokeMessage:
"""
create a link message
:param link: the url of the link
:return: the link message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link)
)
def create_text_message(self, text: str) -> ToolInvokeMessage:
"""
create a text message
:param text: the text
:return: the text message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT,
message=ToolInvokeMessage.TextMessage(text=text),
)
def create_blob_message(self, blob: bytes, meta: dict | None = None) -> ToolInvokeMessage:
"""
create a blob message
:param blob: the blob
:param meta: the meta info of blob object
:return: the blob message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB,
message=ToolInvokeMessage.BlobMessage(blob=blob),
meta=meta,
)
def create_json_message(self, object: dict, suppress_output: bool = False) -> ToolInvokeMessage:
"""
create a json message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.JSON,
message=ToolInvokeMessage.JsonMessage(json_object=object, suppress_output=suppress_output),
)
def create_variable_message(
self, variable_name: str, variable_value: Any, stream: bool = False
) -> ToolInvokeMessage:
"""
create a variable message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.VARIABLE,
message=ToolInvokeMessage.VariableMessage(
variable_name=variable_name, variable_value=variable_value, stream=stream
),
)
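As a rough sketch (not part of this commit), a concrete subclass only has to provide tool_provider_type() and _invoke(); invoke() above then normalizes whatever _invoke returns into a generator of ToolInvokeMessage. EchoTool and its "text" parameter are hypothetical:

from collections.abc import Generator
from typing import Any

class EchoTool(Tool):
    def tool_provider_type(self) -> ToolProviderType:
        return ToolProviderType.BUILT_IN

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
        conversation_id: str | None = None,
        app_id: str | None = None,
        message_id: str | None = None,
    ) -> Generator[ToolInvokeMessage, None, None]:
        # echo the hypothetical "text" parameter back as a text message
        yield self.create_text_message(tool_parameters.get("text", ""))

# hypothetical call; each yielded message is a ToolInvokeMessage of type TEXT
# for message in echo_tool.invoke(user_id="user-1", tool_parameters={"text": "hi"}):
#     ...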

View File

@@ -0,0 +1,107 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any
from core.entities.provider_entities import ProviderConfig
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolProviderEntity,
ToolProviderType,
)
from core.tools.errors import ToolProviderCredentialValidationError
class ToolProviderController(ABC):
def __init__(self, entity: ToolProviderEntity):
self.entity = entity
def get_credentials_schema(self) -> list[ProviderConfig]:
"""
returns the credentials schema of the provider
:return: the credentials schema
"""
return deepcopy(self.entity.credentials_schema)
@abstractmethod
def get_tool(self, tool_name: str) -> Tool:
"""
returns a tool that the provider can provide
:return: tool
"""
pass
@property
def provider_type(self) -> ToolProviderType:
"""
returns the type of the provider
:return: type of the provider
"""
return ToolProviderType.BUILT_IN
def validate_credentials_format(self, credentials: dict[str, Any]):
"""
validate the format of the credentials of the provider and set the default value if needed
:param credentials: the credentials of the tool
"""
        if self.entity.credentials_schema is None:
            return
        credentials_schema = dict[str, ProviderConfig]()
        for credential in self.entity.credentials_schema:
            credentials_schema[credential.name] = credential
credentials_need_to_validate: dict[str, ProviderConfig] = {}
for credential_name in credentials_schema:
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
for credential_name in credentials:
if credential_name not in credentials_need_to_validate:
raise ToolProviderCredentialValidationError(
f"credential {credential_name} not found in provider {self.entity.identity.name}"
)
# check type
credential_schema = credentials_need_to_validate[credential_name]
if not credential_schema.required and credentials[credential_name] is None:
continue
if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
elif credential_schema.type == ProviderConfig.Type.SELECT:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
options = credential_schema.options
if not isinstance(options, list):
raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
if credentials[credential_name] not in [x.value for x in options]:
raise ToolProviderCredentialValidationError(
f"credential {credential_name} should be one of {options}"
)
credentials_need_to_validate.pop(credential_name)
for credential_name in credentials_need_to_validate:
credential_schema = credentials_need_to_validate[credential_name]
if credential_schema.required:
raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
# the credential is not set currently, set the default value if needed
if credential_schema.default is not None:
default_value = credential_schema.default
# parse default value into the correct type
if credential_schema.type in {
ProviderConfig.Type.SECRET_INPUT,
ProviderConfig.Type.TEXT_INPUT,
ProviderConfig.Type.SELECT,
}:
default_value = str(default_value)
credentials[credential_name] = default_value
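A hedged walk-through of the default-filling behaviour above; the provider instance and the "api_base" field are hypothetical:

# Suppose the provider's credentials_schema contains one optional field:
#   ProviderConfig(name="api_base", type=ProviderConfig.Type.TEXT_INPUT,
#                  required=False, default="https://api.example.com")
credentials = {}
# provider.validate_credentials_format(credentials) mutates the dict in place:
# unknown keys raise ToolProviderCredentialValidationError, missing required
# keys raise as well, and missing optional keys receive their default value,
# so afterwards: credentials == {"api_base": "https://api.example.com"}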

View File

@@ -0,0 +1,37 @@
from typing import Any
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.entities.tool_entities import ToolInvokeFrom
class ToolRuntime(BaseModel):
"""
    Metadata of a tool call processing
"""
tenant_id: str
tool_id: str | None = None
invoke_from: InvokeFrom | None = None
tool_invoke_from: ToolInvokeFrom | None = None
credentials: dict[str, Any] = Field(default_factory=dict)
credential_type: CredentialType = Field(default=CredentialType.API_KEY)
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
class FakeToolRuntime(ToolRuntime):
"""
Fake tool runtime for testing
"""
def __init__(self):
super().__init__(
tenant_id="fake_tenant_id",
tool_id="fake_tool_id",
invoke_from=InvokeFrom.DEBUGGER,
tool_invoke_from=ToolInvokeFrom.AGENT,
credentials={},
runtime_parameters={},
)
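A minimal construction sketch using only the fields declared above; the tenant id and credential key are made up:

runtime = ToolRuntime(
    tenant_id="tenant-123",                    # required
    invoke_from=InvokeFrom.DEBUGGER,           # optional provenance
    tool_invoke_from=ToolInvokeFrom.AGENT,
    credentials={"api_key_value": "sk-demo"},  # hypothetical credential key
)
# credential_type defaults to CredentialType.API_KEY and
# runtime_parameters defaults to an empty dict.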

View File

View File

@@ -0,0 +1,4 @@
- audio
- code
- time
- webscraper

View File

@@ -0,0 +1,221 @@
from abc import abstractmethod
from os import listdir, path
from typing import Any
from core.entities.provider_entities import ProviderConfig
from core.helper.module_import_helper import load_single_subclass_from_source
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import (
OAuthSchema,
ToolEntity,
ToolProviderEntity,
ToolProviderType,
)
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
from core.tools.errors import (
ToolProviderNotFoundError,
)
from core.tools.utils.yaml_utils import load_yaml_file_cached
class BuiltinToolProviderController(ToolProviderController):
tools: list[BuiltinTool]
def __init__(self, **data: Any):
self.tools = []
# load provider yaml
provider = self.__class__.__module__.split(".")[-1]
yaml_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, f"{provider}.yaml")
try:
provider_yaml = load_yaml_file_cached(yaml_path)
except Exception as e:
raise ToolProviderNotFoundError(f"can not load provider yaml for {provider}: {e}")
if "credentials_for_provider" in provider_yaml and provider_yaml["credentials_for_provider"] is not None:
# set credentials name
for credential_name in provider_yaml["credentials_for_provider"]:
provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name
credentials_schema = []
for credential in provider_yaml.get("credentials_for_provider", {}):
credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {})
credentials_schema.append(credential_dict)
oauth_schema = None
if provider_yaml.get("oauth_schema", None) is not None:
oauth_schema = OAuthSchema(
client_schema=provider_yaml.get("oauth_schema", {}).get("client_schema", []),
credentials_schema=provider_yaml.get("oauth_schema", {}).get("credentials_schema", []),
)
super().__init__(
entity=ToolProviderEntity(
identity=provider_yaml["identity"],
credentials_schema=credentials_schema,
oauth_schema=oauth_schema,
),
)
self._load_tools()
def _load_tools(self):
provider = self.entity.identity.name
tool_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, "tools")
# get all the yaml files in the tool path
tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
tools = []
for tool_file in tool_files:
# get tool name
tool_name = tool_file.split(".")[0]
tool = load_yaml_file_cached(path.join(tool_path, tool_file))
# get tool class, import the module
assistant_tool_class: type = load_single_subclass_from_source(
module_name=f"core.tools.builtin_tool.providers.{provider}.tools.{tool_name}",
script_path=path.join(
path.dirname(path.realpath(__file__)),
"builtin_tool",
"providers",
provider,
"tools",
f"{tool_name}.py",
),
parent_type=BuiltinTool,
)
tool["identity"]["provider"] = provider
tools.append(
assistant_tool_class(
provider=provider,
entity=ToolEntity.model_validate(tool),
runtime=ToolRuntime(tenant_id=""),
)
)
self.tools = tools
def _get_builtin_tools(self) -> list[BuiltinTool]:
"""
returns a list of tools that the provider can provide
:return: list of tools
"""
return self.tools
def get_credentials_schema(self) -> list[ProviderConfig]:
"""
returns the credentials schema of the provider
:return: the credentials schema
"""
return self.get_credentials_schema_by_type(CredentialType.API_KEY)
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
"""
returns the credentials schema of the provider
:param credential_type: the type of the credential
:return: the credentials schema of the provider
"""
if credential_type == CredentialType.OAUTH2.value:
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
if credential_type == CredentialType.API_KEY:
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
raise ValueError(f"Invalid credential type: {credential_type}")
def get_oauth_client_schema(self) -> list[ProviderConfig]:
"""
returns the oauth client schema of the provider
:return: the oauth client schema
"""
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
def get_supported_credential_types(self) -> list[CredentialType]:
"""
returns the credential support type of the provider
"""
types = []
if self.entity.credentials_schema is not None and len(self.entity.credentials_schema) > 0:
types.append(CredentialType.API_KEY)
if self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) > 0:
types.append(CredentialType.OAUTH2)
return types
def get_tools(self) -> list[BuiltinTool]:
"""
returns a list of tools that the provider can provide
:return: list of tools
"""
return self._get_builtin_tools()
def get_tool(self, tool_name: str) -> BuiltinTool | None: # type: ignore
"""
returns the tool that the provider can provide
"""
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)
@property
def need_credentials(self) -> bool:
"""
returns whether the provider needs credentials
:return: whether the provider needs credentials
"""
return (
self.entity.credentials_schema is not None
and len(self.entity.credentials_schema) != 0
or (self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) != 0)
)
@property
def provider_type(self) -> ToolProviderType:
"""
returns the type of the provider
:return: type of the provider
"""
return ToolProviderType.BUILT_IN
@property
def tool_labels(self) -> list[str]:
"""
returns the labels of the provider
:return: labels of the provider
"""
label_enums = self._get_tool_labels()
return [default_tool_label_dict[label].name for label in label_enums]
def _get_tool_labels(self) -> list[ToolLabelEnum]:
"""
returns the labels of the provider
"""
return self.entity.identity.tags or []
def validate_credentials(self, user_id: str, credentials: dict[str, Any]):
"""
validate the credentials of the provider
        :param user_id: the user id
:param credentials: the credentials of the tool
"""
# validate credentials format
self.validate_credentials_format(credentials)
# validate credentials
self._validate_credentials(user_id, credentials)
@abstractmethod
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
"""
validate the credentials of the provider
        :param user_id: the user id
:param credentials: the credentials of the tool
"""
pass
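From the loading logic in __init__ and _load_tools above, each builtin provider is expected to live under core.tools.builtin_tool.providers.<provider>; this layout is inferred from the paths used there, with "my_provider" and "my_tool" as placeholders:

# providers/
#   my_provider/
#     my_provider.yaml        # identity, credentials_for_provider, oauth_schema
#     tools/
#       my_tool.yaml          # tool identity, description and parameters
#       my_tool.py            # a BuiltinTool subclass discovered via
#                             # load_single_subclass_from_source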

View File

@@ -0,0 +1,20 @@
import os.path
from core.helper.position_helper import get_tool_position_map, sort_by_position_map
from core.tools.entities.api_entities import ToolProviderApiEntity
class BuiltinToolProviderSort:
_position: dict[str, int] = {}
@classmethod
def sort(cls, providers: list[ToolProviderApiEntity]) -> list[ToolProviderApiEntity]:
if not cls._position:
cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), ".."))
def name_func(provider: ToolProviderApiEntity) -> str:
return provider.name
sorted_providers = sort_by_position_map(cls._position, providers, name_func)
return sorted_providers

View File

@@ -0,0 +1,3 @@
<svg xmlns="http://www.w3.org/2000/svg" width="200" height="200" viewBox="0 0 200 200" fill="none">
<path d="M167.358 102.395C167.358 117.174 157.246 129.18 144.61 131.027H137.861C125.225 129.18 115.113 117.174 115.113 102.395H100.792C100.792 123.637 115.118 142.106 133.653 145.801V164.276H147.139V145.801C165.674 142.106 180 124.558 180 102.4H167.358V102.395ZM154.717 62.677C154.717 53.4397 147.979 46.9765 140.396 46.9765C138.523 46.9446 136.663 47.3273 134.924 48.1024C133.185 48.8775 131.603 50.0294 130.27 51.4909C128.936 52.9524 127.878 54.6943 127.157 56.6148C126.436 58.5354 126.066 60.5962 126.07 62.677V78.3775H154.717V70.4478V62.677ZM126.07 102.395C126.07 111.632 132.813 118.095 140.396 118.095C142.269 118.127 144.13 117.744 145.868 116.969C147.607 116.194 149.189 115.042 150.523 113.581C151.856 112.119 152.914 110.377 153.635 108.457C154.356 106.536 154.726 104.475 154.722 102.395V86.694H126.07V102.395ZM92.1297 45.8938L70.4796 21.7595L69.4235 20.5865L59.604 20L68.3674 20.5865L67.3113 21.7654L64.1429 25.2961L63.6149 25.8826L64.1429 27.0614L66.2552 29.4133L77.8723 42.3631H54.1099C35.1 43.5361 20.3146 61.1896 20.3146 81.7874V83.5527H28.2354V81.7932C28.2354 65.8992 39.8525 52.3628 54.1099 51.1899H77.8723L66.2552 64.1338L64.671 65.8992L64.1429 67.0722L63.6149 67.6645L64.1429 68.251L68.3674 72.9606L68.8954 73.5471L69.4235 72.9606L74.1759 67.6645L92.1297 47.6591L92.6578 47.0727L92.1297 45.8938ZM20 95.8496V118.213H30.033V107.034H50.099V168.821H40.066V180H70.165V168.821H60.132V107.034H80.198V118.213H90.231V95.8496H20Z" fill="#FF0099"/>
</svg>


View File

@@ -0,0 +1,8 @@
from typing import Any
from core.tools.builtin_tool.provider import BuiltinToolProviderController
class AudioToolProvider(BuiltinToolProviderController):
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
pass

View File

@@ -0,0 +1,11 @@
identity:
author: hjlarry
name: audio
label:
en_US: Audio
description:
  en_US: A tool for TTS and ASR.
zh_Hans: 一个用于文本转语音和语音转文本的工具。
icon: icon.svg
tags:
- utilities

View File

@@ -0,0 +1,84 @@
import io
from collections.abc import Generator
from typing import Any
from core.file.enums import FileType
from core.file.file_manager import download
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.parameters import PluginParameterOption
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from services.model_provider_service import ModelProviderService
class ASRTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
file = tool_parameters.get("audio_file")
if file.type != FileType.AUDIO: # type: ignore
yield self.create_text_message("not a valid audio file")
return
audio_binary = io.BytesIO(download(file)) # type: ignore
audio_binary.name = "temp.mp3"
provider, model = tool_parameters.get("model").split("#") # type: ignore
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.runtime.tenant_id,
provider=provider,
model_type=ModelType.SPEECH2TEXT,
model=model,
)
text = model_instance.invoke_speech2text(
file=audio_binary,
user=user_id,
)
yield self.create_text_message(text)
def get_available_models(self) -> list[tuple[str, str]]:
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(
tenant_id=self.runtime.tenant_id, model_type="speech2text"
)
items = []
for provider_model in models:
provider = provider_model.provider
for model in provider_model.models:
items.append((provider, model.model))
return items
def get_runtime_parameters(
self,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> list[ToolParameter]:
parameters = []
options = []
for provider, model in self.get_available_models():
option = PluginParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
options.append(option)
parameters.append(
ToolParameter(
name="model",
label=I18nObject(en_US="Model", zh_Hans="Model"),
human_description=I18nObject(
en_US="All available ASR models. You can config model in the Model Provider of Settings.",
zh_Hans="所有可用的 ASR 模型。你可以在设置中的模型供应商里配置。",
),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
required=True,
options=options,
)
)
return parameters
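A hypothetical tool_parameters dict for ASRTool._invoke above; the select value packs provider and model into one string that _invoke splits on "#" (the provider and model names are made up):

# params = {
#     "audio_file": audio_file,       # a File of FileType.AUDIO (placeholder)
#     "model": "openai#whisper-1",    # hypothetical "provider#model" value
# }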

View File

@@ -0,0 +1,22 @@
identity:
name: asr
author: hjlarry
label:
en_US: Speech To Text
description:
human:
en_US: Convert audio file to text.
zh_Hans: 将音频文件转换为文本。
llm: Convert audio file to text.
parameters:
- name: audio_file
type: file
required: true
label:
en_US: Audio File
zh_Hans: 音频文件
human_description:
en_US: The audio file to be converted.
zh_Hans: 要转换的音频文件。
llm_description: The audio file to be converted.
form: llm

View File

@@ -0,0 +1,116 @@
import io
from collections.abc import Generator
from typing import Any
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.plugin.entities.parameters import PluginParameterOption
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from services.model_provider_service import ModelProviderService
class TTSTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
provider, model = tool_parameters.get("model").split("#") # type: ignore
voice = tool_parameters.get(f"voice#{provider}#{model}")
model_manager = ModelManager()
if not self.runtime:
raise ValueError("Runtime is required")
model_instance = model_manager.get_model_instance(
tenant_id=self.runtime.tenant_id or "",
provider=provider,
model_type=ModelType.TTS,
model=model,
)
if not voice:
voices = model_instance.get_tts_voices()
if voices:
voice = voices[0].get("value")
if not voice:
raise ValueError("Sorry, no voice available.")
else:
raise ValueError("Sorry, no voice available.")
tts = model_instance.invoke_tts(
content_text=tool_parameters.get("text"), # type: ignore
user=user_id,
tenant_id=self.runtime.tenant_id,
voice=voice,
)
buffer = io.BytesIO()
for chunk in tts:
buffer.write(chunk)
wav_bytes = buffer.getvalue()
yield self.create_text_message("Audio generated successfully")
yield self.create_blob_message(
blob=wav_bytes,
meta={"mime_type": "audio/x-wav"},
)
def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
if not self.runtime:
raise ValueError("Runtime is required")
model_provider_service = ModelProviderService()
tid: str = self.runtime.tenant_id or ""
models = model_provider_service.get_models_by_model_type(tenant_id=tid, model_type="tts")
items = []
for provider_model in models:
provider = provider_model.provider
for model in provider_model.models:
voices = model.model_properties.get(ModelPropertyKey.VOICES, [])
items.append((provider, model.model, voices))
return items
def get_runtime_parameters(
self,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> list[ToolParameter]:
parameters = []
options = []
for provider, model, voices in self.get_available_models():
option = PluginParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
options.append(option)
parameters.append(
ToolParameter(
name=f"voice#{provider}#{model}",
label=I18nObject(en_US=f"Voice of {model}({provider})"),
human_description=I18nObject(en_US=f"Select a voice for {model} model"),
placeholder=I18nObject(en_US="Select a voice"),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
options=[
PluginParameterOption(value=voice.get("mode"), label=I18nObject(en_US=voice.get("name")))
for voice in voices
],
)
)
parameters.insert(
0,
ToolParameter(
name="model",
label=I18nObject(en_US="Model", zh_Hans="Model"),
human_description=I18nObject(
en_US="All available TTS models. You can config model in the Model Provider of Settings.",
zh_Hans="所有可用的 TTS 模型。你可以在设置中的模型供应商里配置。",
),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
required=True,
placeholder=I18nObject(en_US="Select a model", zh_Hans="选择模型"),
options=options,
),
)
return parameters

View File

@@ -0,0 +1,22 @@
identity:
name: tts
author: hjlarry
label:
en_US: Text To Speech
description:
human:
en_US: Convert text to audio file.
zh_Hans: 将文本转换为音频文件。
llm: Convert text to audio file.
parameters:
- name: text
type: string
required: true
label:
en_US: Text
zh_Hans: 文本
human_description:
en_US: The text to be converted.
zh_Hans: 要转换的文本。
llm_description: The text to be converted.
form: llm

View File

@@ -0,0 +1 @@
<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg" class="w-3.5 h-3.5" data-icon="Code" aria-hidden="true"><g id="icons/code"><path id="Vector (Stroke)" fill-rule="evenodd" clip-rule="evenodd" d="M8.32593 1.69675C8.67754 1.78466 8.89132 2.14096 8.80342 2.49257L6.47009 11.8259C6.38218 12.1775 6.02588 12.3913 5.67427 12.3034C5.32265 12.2155 5.10887 11.8592 5.19678 11.5076L7.53011 2.17424C7.61801 1.82263 7.97431 1.60885 8.32593 1.69675ZM3.96414 4.20273C4.22042 4.45901 4.22042 4.87453 3.96413 5.13081L2.45578 6.63914C2.45577 6.63915 2.45578 6.63914 2.45578 6.63914C2.25645 6.83851 2.25643 7.16168 2.45575 7.36103C2.45574 7.36103 2.45576 7.36104 2.45575 7.36103L3.96413 8.86936C4.22041 9.12564 4.22042 9.54115 3.96414 9.79744C3.70787 10.0537 3.29235 10.0537 3.03607 9.79745L1.52769 8.28913C0.815811 7.57721 0.815803 6.42302 1.52766 5.7111L3.03606 4.20272C3.29234 3.94644 3.70786 3.94644 3.96414 4.20273ZM10.0361 4.20273C10.2923 3.94644 10.7078 3.94644 10.9641 4.20272L12.4725 5.71108C13.1843 6.423 13.1844 7.57717 12.4725 8.28909L10.9641 9.79745C10.7078 10.0537 10.2923 10.0537 10.036 9.79744C9.77977 9.54115 9.77978 9.12564 10.0361 8.86936L11.5444 7.36107C11.7437 7.16172 11.7438 6.83854 11.5444 6.63917C11.5444 6.63915 11.5445 6.63918 11.5444 6.63917L10.0361 5.13081C9.77978 4.87453 9.77978 4.45901 10.0361 4.20273Z" fill="#2e90fa"></path></g></svg>


View File

@@ -0,0 +1,8 @@
from typing import Any
from core.tools.builtin_tool.provider import BuiltinToolProviderController
class CodeToolProvider(BuiltinToolProviderController):
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
pass

View File

@@ -0,0 +1,14 @@
identity:
author: Dify
name: code
label:
en_US: Code Interpreter
zh_Hans: 代码解释器
pt_BR: Interpretador de Código
description:
en_US: Run a piece of code and get the result back.
zh_Hans: 运行一段代码并返回结果。
pt_BR: Execute um trecho de código e obtenha o resultado de volta.
icon: icon.svg
tags:
- productivity

View File

@@ -0,0 +1,33 @@
from collections.abc import Generator
from typing import Any
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolInvokeError
class SimpleCode(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
invoke simple code
"""
language = tool_parameters.get("language", CodeLanguage.PYTHON3)
code = tool_parameters.get("code", "")
if language not in {CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT}:
raise ValueError(f"Only python3 and javascript are supported, not {language}")
try:
result = CodeExecutor.execute_code(language, "", code)
yield self.create_text_message(result)
except Exception as e:
            raise ToolInvokeError(str(e)) from e
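A hypothetical parameter dict for SimpleCode._invoke, using the CodeLanguage enum imported above; per the tool description in the yaml that follows, the snippet must print its result because only stdout comes back from the executor:

params = {
    "language": CodeLanguage.PYTHON3,
    "code": "print(1 + 1)",   # prints "2", which becomes the text message
}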

View File

@@ -0,0 +1,51 @@
identity:
name: simple_code
author: Dify
label:
en_US: Code Interpreter
zh_Hans: 代码解释器
pt_BR: Interpretador de Código
description:
human:
    en_US: Run code and get the result back. When you're using a lower quality model, please make sure there are some tips to help the LLM understand how to write the code.
zh_Hans: 运行一段代码并返回结果。当您使用较低质量的模型时,请确保有一些提示帮助 LLM 理解如何编写代码。
pt_BR: Execute um trecho de código e obtenha o resultado de volta. quando você estiver usando um modelo de qualidade inferior, certifique-se de que existam algumas dicas para ajudar o LLM a entender como escrever o código.
  llm: A tool for running code and getting the result back. Only native packages are allowed and network/IO operations are disabled; you must use print() or console.log() to output the result, or the result will be empty.
parameters:
- name: language
type: string
required: true
label:
en_US: Language
zh_Hans: 语言
pt_BR: Idioma
human_description:
en_US: The programming language of the code
zh_Hans: 代码的编程语言
pt_BR: A linguagem de programação do código
llm_description: language of the code, only "python3" and "javascript" are supported
form: llm
options:
- value: python3
label:
en_US: Python3
zh_Hans: Python3
pt_BR: Python3
- value: javascript
label:
en_US: JavaScript
zh_Hans: JavaScript
pt_BR: JavaScript
- name: code
type: string
required: true
label:
en_US: Code
zh_Hans: 代码
pt_BR: Código
human_description:
en_US: The code to be executed
zh_Hans: 要执行的代码
pt_BR: O código a ser executado
llm_description: code to be executed, only native packages are allowed, network/IO operations are disabled.
form: llm

View File

@@ -0,0 +1,3 @@
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path fill-rule="evenodd" clip-rule="evenodd" d="M0.666992 8.00008C0.666992 3.94999 3.95024 0.666748 8.00033 0.666748C12.0504 0.666748 15.3337 3.94999 15.3337 8.00008C15.3337 12.0502 12.0504 15.3334 8.00033 15.3334C3.95024 15.3334 0.666992 12.0502 0.666992 8.00008ZM8.66699 4.00008C8.66699 3.63189 8.36852 3.33341 8.00033 3.33341C7.63213 3.33341 7.33366 3.63189 7.33366 4.00008V8.00008C7.33366 8.2526 7.47633 8.48344 7.70218 8.59637L10.3688 9.9297C10.6982 10.0944 11.0986 9.96088 11.2633 9.63156C11.4279 9.30224 11.2945 8.90179 10.9651 8.73713L8.66699 7.58806V4.00008Z" fill="#EC4A0A"/>
</svg>


View File

@@ -0,0 +1,8 @@
from typing import Any
from core.tools.builtin_tool.provider import BuiltinToolProviderController
class WikiPediaProvider(BuiltinToolProviderController):
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
pass

View File

@@ -0,0 +1,14 @@
identity:
author: Dify
name: time
label:
en_US: CurrentTime
zh_Hans: 时间
pt_BR: CurrentTime
description:
en_US: A tool for getting the current time.
zh_Hans: 一个用于获取当前时间的工具。
pt_BR: A tool for getting the current time.
icon: icon.svg
tags:
- utilities

View File

@@ -0,0 +1,35 @@
from collections.abc import Generator
from datetime import UTC, datetime
from typing import Any
from pytz import timezone as pytz_timezone
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
class CurrentTimeTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
invoke tools
"""
# get timezone
tz = tool_parameters.get("timezone", "UTC")
fm = tool_parameters.get("format") or "%Y-%m-%d %H:%M:%S %Z"
if tz == "UTC":
yield self.create_text_message(f"{datetime.now(UTC).strftime(fm)}")
return
try:
tz = pytz_timezone(tz)
except Exception:
yield self.create_text_message(f"Invalid timezone: {tz}")
return
yield self.create_text_message(f"{datetime.now(tz).strftime(fm)}")
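Hypothetical parameters for CurrentTimeTool._invoke; "timezone" must be a pytz zone name and "format" a strftime pattern, both defined in the yaml that follows:

params = {"timezone": "Europe/London", "format": "%Y-%m-%d %H:%M:%S %Z"}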

View File

@@ -0,0 +1,131 @@
identity:
name: current_time
author: Dify
label:
en_US: Current Time
zh_Hans: 获取当前时间
pt_BR: Current Time
description:
human:
en_US: A tool for getting the current time.
zh_Hans: 一个用于获取当前时间的工具。
pt_BR: A tool for getting the current time.
llm: A tool for getting the current time.
parameters:
- name: format
type: string
required: false
label:
en_US: Format
zh_Hans: 格式
pt_BR: Format
human_description:
en_US: Time format in strftime standard.
zh_Hans: strftime 标准的时间格式。
pt_BR: Time format in strftime standard.
form: form
default: "%Y-%m-%d %H:%M:%S"
- name: timezone
type: select
required: false
label:
en_US: Timezone
zh_Hans: 时区
pt_BR: Timezone
human_description:
en_US: Timezone
zh_Hans: 时区
pt_BR: Timezone
form: form
default: UTC
options:
- value: UTC
label:
en_US: UTC
zh_Hans: UTC
pt_BR: UTC
- value: America/New_York
label:
en_US: America/New_York
zh_Hans: 美洲/纽约
pt_BR: America/New_York
- value: America/Los_Angeles
label:
en_US: America/Los_Angeles
zh_Hans: 美洲/洛杉矶
pt_BR: America/Los_Angeles
- value: America/Chicago
label:
en_US: America/Chicago
zh_Hans: 美洲/芝加哥
pt_BR: America/Chicago
- value: America/Sao_Paulo
label:
en_US: America/Sao_Paulo
zh_Hans: 美洲/圣保罗
pt_BR: América/São Paulo
- value: Asia/Shanghai
label:
en_US: Asia/Shanghai
zh_Hans: 亚洲/上海
pt_BR: Asia/Shanghai
- value: Asia/Ho_Chi_Minh
label:
en_US: Asia/Ho_Chi_Minh
zh_Hans: 亚洲/胡志明市
pt_BR: Ásia/Ho Chi Minh
- value: Asia/Tokyo
label:
en_US: Asia/Tokyo
zh_Hans: 亚洲/东京
pt_BR: Asia/Tokyo
- value: Asia/Dubai
label:
en_US: Asia/Dubai
zh_Hans: 亚洲/迪拜
pt_BR: Asia/Dubai
- value: Asia/Kolkata
label:
en_US: Asia/Kolkata
zh_Hans: 亚洲/加尔各答
pt_BR: Asia/Kolkata
- value: Asia/Seoul
label:
en_US: Asia/Seoul
zh_Hans: 亚洲/首尔
pt_BR: Asia/Seoul
- value: Asia/Singapore
label:
en_US: Asia/Singapore
zh_Hans: 亚洲/新加坡
pt_BR: Asia/Singapore
- value: Europe/London
label:
en_US: Europe/London
zh_Hans: 欧洲/伦敦
pt_BR: Europe/London
- value: Europe/Berlin
label:
en_US: Europe/Berlin
zh_Hans: 欧洲/柏林
pt_BR: Europe/Berlin
- value: Europe/Moscow
label:
en_US: Europe/Moscow
zh_Hans: 欧洲/莫斯科
pt_BR: Europe/Moscow
- value: Australia/Sydney
label:
en_US: Australia/Sydney
zh_Hans: 澳大利亚/悉尼
pt_BR: Australia/Sydney
- value: Pacific/Auckland
label:
en_US: Pacific/Auckland
zh_Hans: 太平洋/奥克兰
pt_BR: Pacific/Auckland
- value: Africa/Cairo
label:
en_US: Africa/Cairo
zh_Hans: 非洲/开罗
pt_BR: Africa/Cairo

View File

@@ -0,0 +1,50 @@
from collections.abc import Generator
from datetime import datetime
from typing import Any
import pytz
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolInvokeError
class LocaltimeToTimestampTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
Convert localtime to timestamp
"""
localtime = tool_parameters.get("localtime")
timezone = tool_parameters.get("timezone", "Asia/Shanghai")
if not timezone:
timezone = None
time_format = "%Y-%m-%d %H:%M:%S"
timestamp = self.localtime_to_timestamp(localtime, time_format, timezone) # type: ignore
if not timestamp:
yield self.create_text_message(f"Invalid localtime: {localtime}")
return
yield self.create_text_message(f"{timestamp}")
    @staticmethod
    def localtime_to_timestamp(localtime: str, time_format: str, local_tz: str | None = None) -> int | None:
        try:
            local_time = datetime.strptime(localtime, time_format)
            if local_tz is None:
                # no timezone given, fall back to the system local timezone
                localized = local_time.astimezone()
            else:
                localized = pytz.timezone(local_tz).localize(local_time)
            return int(localized.timestamp())
        except Exception as e:
            raise ToolInvokeError(str(e))
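A hedged usage example of the helper above, with values chosen for illustration: midnight on 2024-01-01 in Asia/Shanghai is 16:00 UTC on 2023-12-31, i.e. Unix timestamp 1704038400.

# LocaltimeToTimestampTool.localtime_to_timestamp(
#     "2024-01-01 00:00:00", "%Y-%m-%d %H:%M:%S", "Asia/Shanghai"
# )  -> 1704038400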

View File

@@ -0,0 +1,33 @@
identity:
name: localtime_to_timestamp
author: zhuhao
label:
en_US: localtime to timestamp
zh_Hans: 获取时间戳
description:
human:
en_US: A tool for localtime convert to timestamp
zh_Hans: 获取时间戳
llm: A tool for localtime convert to timestamp
parameters:
- name: localtime
type: string
required: true
form: llm
label:
en_US: localtime
zh_Hans: 本地时间
human_description:
en_US: localtime, such as 2024-1-1 0:0:0
zh_Hans: 本地时间,比如 2024-1-1 0:0:0
- name: timezone
type: string
required: false
form: llm
label:
en_US: Timezone
zh_Hans: 时区
human_description:
en_US: Timezone, such as Asia/Shanghai
zh_Hans: 时区,比如 Asia/Shanghai
default: Asia/Shanghai

View File

@@ -0,0 +1,49 @@
from collections.abc import Generator
from datetime import datetime
from typing import Any
import pytz
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolInvokeError
class TimestampToLocaltimeTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
Convert timestamp to localtime
"""
timestamp: int = tool_parameters.get("timestamp", 0)
timezone = tool_parameters.get("timezone", "Asia/Shanghai")
if not timezone:
timezone = None
time_format = "%Y-%m-%d %H:%M:%S"
        localtime = self.timestamp_to_localtime(timestamp, timezone)
        if not localtime:
            yield self.create_text_message(f"Invalid timestamp: {timestamp}")
            return
        localtime_format = localtime.strftime(time_format)
yield self.create_text_message(f"{localtime_format}")
@staticmethod
def timestamp_to_localtime(timestamp: int, local_tz=None) -> datetime | None:
try:
if local_tz is None:
local_tz = datetime.now().astimezone().tzinfo
if isinstance(local_tz, str):
local_tz = pytz.timezone(local_tz)
local_time = datetime.fromtimestamp(timestamp, local_tz)
return local_time
except Exception as e:
raise ToolInvokeError(str(e))
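And the reverse direction, as a quick hypothetical sanity check: timestamp 0 rendered in UTC is the Unix epoch.

# TimestampToLocaltimeTool.timestamp_to_localtime(0, "UTC")
# -> datetime(1970, 1, 1, 0, 0, tzinfo=<UTC>)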

View File

@@ -0,0 +1,33 @@
identity:
name: timestamp_to_localtime
author: zhuhao
label:
en_US: Timestamp to localtime
zh_Hans: 时间戳转换
description:
human:
en_US: A tool for timestamp convert to localtime
zh_Hans: 时间戳转换
llm: A tool for timestamp convert to localtime
parameters:
- name: timestamp
type: number
required: true
form: llm
label:
en_US: Timestamp
zh_Hans: 时间戳
human_description:
en_US: Timestamp
zh_Hans: 时间戳
- name: timezone
type: string
required: false
form: llm
label:
en_US: Timezone
zh_Hans: 时区
human_description:
en_US: Timezone, such as Asia/Shanghai
zh_Hans: 时区,比如 Asia/Shanghai
default: Asia/Shanghai

View File

@@ -0,0 +1,53 @@
from collections.abc import Generator
from datetime import datetime
from typing import Any
import pytz
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolInvokeError
class TimezoneConversionTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
Convert time to equivalent time zone
"""
current_time = tool_parameters.get("current_time")
current_timezone = tool_parameters.get("current_timezone", "Asia/Shanghai")
target_timezone = tool_parameters.get("target_timezone", "Asia/Tokyo")
target_time = self.timezone_convert(current_time, current_timezone, target_timezone) # type: ignore
if not target_time:
yield self.create_text_message(
f"Invalid datetime and timezone: {current_time},{current_timezone},{target_timezone}"
)
return
yield self.create_text_message(f"{target_time}")
@staticmethod
def timezone_convert(current_time: str, source_timezone: str, target_timezone: str) -> str:
"""
Convert a time string from source timezone to target timezone.
"""
time_format = "%Y-%m-%d %H:%M:%S"
try:
# get source timezone
input_timezone = pytz.timezone(source_timezone)
# get target timezone
output_timezone = pytz.timezone(target_timezone)
local_time = datetime.strptime(current_time, time_format)
datetime_with_tz = input_timezone.localize(local_time)
# timezone convert
converted_datetime = datetime_with_tz.astimezone(output_timezone)
return converted_datetime.strftime(time_format)
except Exception as e:
raise ToolInvokeError(str(e))
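A hypothetical conversion using the static helper above; Asia/Shanghai is UTC+8 and Asia/Tokyo is UTC+9, so the result is one hour later:

# TimezoneConversionTool.timezone_convert(
#     "2024-01-01 08:00:00", "Asia/Shanghai", "Asia/Tokyo"
# )  -> "2024-01-01 09:00:00"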

View File

@@ -0,0 +1,44 @@
identity:
name: timezone_conversion
author: zhuhao
label:
en_US: convert time to equivalent time zone
zh_Hans: 时区转换
description:
human:
en_US: A tool to convert time to equivalent time zone
zh_Hans: 时区转换
llm: A tool to convert time to equivalent time zone
parameters:
- name: current_time
type: string
required: true
form: llm
label:
en_US: current time
zh_Hans: 当前时间
human_description:
en_US: current time, such as 2024-1-1 0:0:0
zh_Hans: 当前时间,比如 2024-1-1 0:0:0
- name: current_timezone
type: string
required: true
form: llm
label:
en_US: Current Timezone
zh_Hans: 当前时区
human_description:
en_US: Current Timezone, such as Asia/Shanghai
zh_Hans: 当前时区,比如 Asia/Shanghai
default: Asia/Shanghai
- name: target_timezone
type: string
required: true
form: llm
label:
en_US: Target Timezone
zh_Hans: 目标时区
human_description:
en_US: Target Timezone, such as Asia/Tokyo
zh_Hans: 目标时区,比如 Asia/Tokyo
default: Asia/Tokyo

View File

@@ -0,0 +1,50 @@
import calendar
from collections.abc import Generator
from datetime import datetime
from typing import Any
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
class WeekdayTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
Calculate the day of the week for a given date
"""
year = tool_parameters.get("year")
month = tool_parameters.get("month")
if month is None:
raise ValueError("Month is required")
day = tool_parameters.get("day")
date_obj = self.convert_datetime(year, month, day)
if not date_obj:
yield self.create_text_message(f"Invalid date: Year {year}, Month {month}, Day {day}.")
return
weekday_name = calendar.day_name[date_obj.weekday()]
        month_name = calendar.month_name[date_obj.month]
readable_date = f"{month_name} {date_obj.day}, {date_obj.year}"
yield self.create_text_message(f"{readable_date} is {weekday_name}.")
@staticmethod
def convert_datetime(year, month, day) -> datetime | None:
        try:
            year = int(year)
            month = int(month)
            day = int(day)
            # allowed range in the datetime module
            if not (year >= 1 and 1 <= month <= 12 and 1 <= day <= 31):
                return None
            return datetime(year, month, day)
        except (TypeError, ValueError):
            return None
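A quick hypothetical check of the tool above: January 1, 2024 fell on a Monday.

# date_obj = WeekdayTool.convert_datetime(2024, 1, 1)
# calendar.day_name[date_obj.weekday()]  -> "Monday"
# resulting message: "January 1, 2024 is Monday."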

View File

@@ -0,0 +1,42 @@
identity:
name: weekday
author: Bowen Liang
label:
en_US: Weekday Calculator
zh_Hans: 星期几计算器
description:
human:
en_US: A tool for calculating the weekday of a given date.
zh_Hans: 计算指定日期为星期几的工具。
llm: A tool for calculating the weekday of a given date by year, month and day.
parameters:
- name: year
type: number
required: true
form: llm
label:
en_US: Year
zh_Hans:
human_description:
en_US: Year
zh_Hans:
- name: month
type: number
required: true
form: llm
label:
en_US: Month
zh_Hans:
human_description:
en_US: Month
zh_Hans:
- name: day
type: number
required: true
form: llm
label:
en_US: day
zh_Hans:
human_description:
en_US: day
zh_Hans:

View File

@@ -0,0 +1,3 @@
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="17" viewBox="0 0 16 17" fill="none">
<path fill-rule="evenodd" clip-rule="evenodd" d="M2.6665 1.16667C1.56193 1.16667 0.666504 2.0621 0.666504 3.16667C0.666504 4.27124 1.56193 5.16667 2.6665 5.16667C2.79161 5.16667 2.91403 5.15519 3.03277 5.13321C2.3808 6.09319 1.99984 7.25211 1.99984 8.5C1.99984 9.7479 2.3808 10.9068 3.03277 11.8668C2.91403 11.8448 2.79161 11.8333 2.6665 11.8333C1.56193 11.8333 0.666504 12.7288 0.666504 13.8333C0.666504 14.9379 1.56193 15.8333 2.6665 15.8333C3.77107 15.8333 4.6665 14.9379 4.6665 13.8333C4.6665 13.7082 4.65502 13.5858 4.63304 13.4671C5.59302 14.119 6.75194 14.5 7.99984 14.5C9.24773 14.5 10.4066 14.119 11.3666 13.4671C11.3447 13.5858 11.3332 13.7082 11.3332 13.8333C11.3332 14.9379 12.2286 15.8333 13.3332 15.8333C14.4377 15.8333 15.3332 14.9379 15.3332 13.8333C15.3332 12.7288 14.4377 11.8333 13.3332 11.8333C13.2081 11.8333 13.0856 11.8448 12.9669 11.8668C13.6189 10.9068 13.9998 9.7479 13.9998 8.5C13.9998 7.25211 13.6189 6.09319 12.9669 5.13321C13.0856 5.15519 13.2081 5.16667 13.3332 5.16667C14.4377 5.16667 15.3332 4.27124 15.3332 3.16667C15.3332 2.0621 14.4377 1.16667 13.3332 1.16667C12.2286 1.16667 11.3332 2.0621 11.3332 3.16667C11.3332 3.29177 11.3447 3.41419 11.3666 3.53293C10.4066 2.88097 9.24773 2.50001 7.99984 2.50001C6.75194 2.50001 5.59302 2.88097 4.63304 3.53293C4.65502 3.41419 4.6665 3.29177 4.6665 3.16667C4.6665 2.0621 3.77107 1.16667 2.6665 1.16667ZM3.38043 7.83334C3.63081 6.08287 4.85262 4.64578 6.48223 4.08565C5.79223 5.22099 5.36488 6.50185 5.23815 7.83334H3.38043ZM6.48228 12.9144C4.85264 12.3543 3.63082 10.9172 3.38043 9.16667H5.23815C5.3649 10.4982 5.79226 11.779 6.48228 12.9144ZM12.6192 9.16667C12.3689 10.9168 11.1475 12.3537 9.5183 12.9141C10.2082 11.7788 10.6355 10.498 10.7622 9.16667H12.6192ZM9.51834 4.08596C11.1475 4.64631 12.3689 6.0832 12.6192 7.83334H10.7622C10.6355 6.50197 10.2082 5.22123 9.51834 4.08596ZM9.4218 7.83334C9.27457 6.52262 8.78381 5.27411 8.00019 4.2145C7.21658 5.27411 6.72582 6.52262 6.57859 7.83334H9.4218ZM6.5786 9.16667C6.72583 10.4774 7.21659 11.7259 8.00019 12.7855C8.7838 11.7259 9.27456 10.4774 9.42179 9.16667H6.5786Z" fill="#DD2590"/>
</svg>


View File

@@ -0,0 +1,39 @@
from collections.abc import Generator
from typing import Any
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolInvokeError
from core.tools.utils.web_reader_tool import get_url
class WebscraperTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
invoke tools
"""
try:
url = tool_parameters.get("url", "")
user_agent = tool_parameters.get("user_agent", "")
if not url:
yield self.create_text_message("Please input url")
return
# get webpage
result = get_url(url, user_agent=user_agent)
if tool_parameters.get("generate_summary"):
# summarize and return
yield self.create_text_message(self.summary(user_id=user_id, content=result))
else:
# return full webpage
yield self.create_text_message(result)
except Exception as e:
raise ToolInvokeError(str(e))
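Hypothetical parameters for WebscraperTool._invoke; with generate_summary enabled the scraped page goes through BuiltinTool.summary instead of being returned verbatim:

params = {
    "url": "https://example.com",
    "user_agent": "Mozilla/5.0",   # optional, forwarded to get_url
    "generate_summary": False,     # True -> summarize via the LLM helper
}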

View File

@@ -0,0 +1,60 @@
identity:
name: webscraper
author: Dify
label:
en_US: Web Scraper
zh_Hans: 网页爬虫
pt_BR: Web Scraper
description:
human:
en_US: A tool for scraping webpages.
zh_Hans: 一个用于爬取网页的工具。
pt_BR: A tool for scraping webpages.
llm: A tool for scraping webpages. Input should be a URL.
parameters:
- name: url
type: string
required: true
label:
en_US: URL
zh_Hans: 网页链接
pt_BR: URL
human_description:
en_US: used for linking to webpages
zh_Hans: 用于链接到网页
pt_BR: used for linking to webpages
llm_description: url for scraping
form: llm
- name: user_agent
type: string
required: false
label:
en_US: User Agent
zh_Hans: User Agent
pt_BR: User Agent
human_description:
en_US: used for identifying the browser.
zh_Hans: 用于识别浏览器。
pt_BR: used for identifying the browser.
form: form
default: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/100.0.1000.0 Safari/537.36
- name: generate_summary
type: boolean
required: false
label:
en_US: Whether to generate summary
zh_Hans: 是否生成摘要
human_description:
en_US: If true, the crawler will only return the page summary content.
zh_Hans: 如果启用,爬虫将仅返回页面摘要内容。
form: form
options:
- value: "true"
label:
en_US: "Yes"
zh_Hans:
- value: "false"
label:
en_US: "No"
zh_Hans:
default: "false"

View File

@@ -0,0 +1,11 @@
from typing import Any
from core.tools.builtin_tool.provider import BuiltinToolProviderController
class WebscraperProvider(BuiltinToolProviderController):
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
"""
Validate credentials
"""
pass

View File

@@ -0,0 +1,15 @@
identity:
author: Dify
name: webscraper
label:
en_US: WebScraper
zh_Hans: 网页抓取
pt_BR: WebScraper
description:
  en_US: Web Scraper toolkit is used to scrape the web
zh_Hans: 一个用于抓取网页的工具。
  pt_BR: Web Scraper toolkit is used to scrape the web
icon: icon.svg
tags:
- productivity
credentials_for_provider: {}

View File

@@ -0,0 +1,148 @@
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
_SUMMARY_PROMPT = """You are a professional language researcher. You are interested in language,
and you can quickly identify the main point of a webpage and reproduce it in your own words while
retaining the original meaning and keeping the key points.
However, the text you received is too long and may be only a part of the full text.
Please summarize the text you received.
"""
class BuiltinTool(Tool):
"""
Builtin tool
    :param provider: the name of the provider this tool belongs to
"""
def __init__(self, provider: str, **kwargs):
super().__init__(**kwargs)
self.provider = provider
def fork_tool_runtime(self, runtime: ToolRuntime) -> "BuiltinTool":
"""
fork a new tool with metadata
:return: the new tool
"""
return self.__class__(
entity=self.entity.model_copy(),
runtime=runtime,
provider=self.provider,
)
def invoke_model(self, user_id: str, prompt_messages: list[PromptMessage], stop: list[str]) -> LLMResult:
"""
invoke model
:param user_id: the user id
:param prompt_messages: the prompt messages
:param stop: the stop words
:return: the model result
"""
# invoke model
return ModelInvocationUtils.invoke(
user_id=user_id,
tenant_id=self.runtime.tenant_id or "",
tool_type="builtin",
tool_name=self.entity.identity.name,
prompt_messages=prompt_messages,
)
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.BUILT_IN
def get_max_tokens(self) -> int:
"""
get max tokens
:return: the max tokens
"""
if self.runtime is None:
raise ValueError("runtime is required")
return ModelInvocationUtils.get_max_llm_context_tokens(
tenant_id=self.runtime.tenant_id or "",
)
def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
"""
get prompt tokens
:param prompt_messages: the prompt messages
:return: the tokens
"""
if self.runtime is None:
raise ValueError("runtime is required")
return ModelInvocationUtils.calculate_tokens(
tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages
)
def summary(self, user_id: str, content: str) -> str:
max_tokens = self.get_max_tokens()
if self.get_prompt_tokens(prompt_messages=[UserPromptMessage(content=content)]) < max_tokens * 0.6:
return content
def get_prompt_tokens(content: str) -> int:
return self.get_prompt_tokens(
prompt_messages=[SystemPromptMessage(content=_SUMMARY_PROMPT), UserPromptMessage(content=content)]
)
def summarize(content: str) -> str:
summary = self.invoke_model(
user_id=user_id,
prompt_messages=[SystemPromptMessage(content=_SUMMARY_PROMPT), UserPromptMessage(content=content)],
stop=[],
)
assert isinstance(summary.message.content, str)
return summary.message.content
lines = content.split("\n")
new_lines = []
# split long line into multiple lines
for i in range(len(lines)):
line = lines[i]
if not line.strip():
continue
if len(line) < max_tokens * 0.5:
new_lines.append(line)
elif get_prompt_tokens(line) > max_tokens * 0.7:
while get_prompt_tokens(line) > max_tokens * 0.7:
new_lines.append(line[: int(max_tokens * 0.5)])
line = line[int(max_tokens * 0.5) :]
new_lines.append(line)
else:
new_lines.append(line)
# merge lines into messages with max tokens
messages: list[str] = []
        for j in new_lines:
            if len(messages) == 0:
                messages.append(j)
            elif len(messages[-1]) + len(j) < max_tokens * 0.5:
                # short enough to merge into the previous chunk directly
                messages[-1] += j
            elif get_prompt_tokens(messages[-1] + j) > max_tokens * 0.7:
                # merging would exceed the token budget, start a new chunk
                messages.append(j)
            else:
                messages[-1] += j
summaries = []
for i in range(len(messages)):
message = messages[i]
summary = summarize(message)
summaries.append(summary)
result = "\n".join(summaries)
if self.get_prompt_tokens(prompt_messages=[UserPromptMessage(content=result)]) > max_tokens * 0.7:
return self.summary(user_id=user_id, content=result)
return result

View File

@@ -0,0 +1,209 @@
from pydantic import Field
from sqlalchemy import select
from core.entities.provider_entities import ProviderConfig
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.custom_tool.tool import ApiTool
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolDescription,
ToolEntity,
ToolIdentity,
ToolProviderEntity,
ToolProviderIdentity,
ToolProviderType,
)
from extensions.ext_database import db
from models.tools import ApiToolProvider
class ApiToolProviderController(ToolProviderController):
provider_id: str
tenant_id: str
tools: list[ApiTool] = Field(default_factory=list)
def __init__(self, entity: ToolProviderEntity, provider_id: str, tenant_id: str):
super().__init__(entity)
self.provider_id = provider_id
self.tenant_id = tenant_id
self.tools = []
@classmethod
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
credentials_schema = [
ProviderConfig(
name="auth_type",
required=True,
type=ProviderConfig.Type.SELECT,
options=[
ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="")),
ProviderConfig.Option(value="api_key_header", label=I18nObject(en_US="Header", zh_Hans="请求头")),
ProviderConfig.Option(
value="api_key_query", label=I18nObject(en_US="Query Param", zh_Hans="查询参数")
),
],
default="none",
help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"),
)
]
if auth_type == ApiProviderAuthType.API_KEY_HEADER:
credentials_schema = [
*credentials_schema,
ProviderConfig(
name="api_key_header",
required=False,
default="Authorization",
type=ProviderConfig.Type.TEXT_INPUT,
help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"),
),
ProviderConfig(
name="api_key_value",
required=True,
type=ProviderConfig.Type.SECRET_INPUT,
help=I18nObject(en_US="The api key", zh_Hans="api key 的值"),
),
ProviderConfig(
name="api_key_header_prefix",
required=False,
default="basic",
type=ProviderConfig.Type.SELECT,
help=I18nObject(en_US="The prefix of the api key header", zh_Hans="api key header 的前缀"),
options=[
ProviderConfig.Option(value="basic", label=I18nObject(en_US="Basic", zh_Hans="Basic")),
ProviderConfig.Option(value="bearer", label=I18nObject(en_US="Bearer", zh_Hans="Bearer")),
ProviderConfig.Option(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")),
],
),
]
elif auth_type == ApiProviderAuthType.API_KEY_QUERY:
credentials_schema = [
*credentials_schema,
ProviderConfig(
name="api_key_query_param",
required=False,
default="key",
type=ProviderConfig.Type.TEXT_INPUT,
help=I18nObject(
en_US="The query parameter name of the api key", zh_Hans="携带 api key 的查询参数名称"
),
),
ProviderConfig(
name="api_key_value",
required=True,
type=ProviderConfig.Type.SECRET_INPUT,
help=I18nObject(en_US="The api key", zh_Hans="api key 的值"),
),
]
elif auth_type == ApiProviderAuthType.NONE:
pass
user = db_provider.user
user_name = user.name if user else ""
return ApiToolProviderController(
entity=ToolProviderEntity(
identity=ToolProviderIdentity(
author=user_name,
name=db_provider.name,
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
icon=db_provider.icon,
),
credentials_schema=credentials_schema,
plugin_id=None,
),
provider_id=db_provider.id or "",
tenant_id=db_provider.tenant_id or "",
)
@property
def provider_type(self) -> ToolProviderType:
return ToolProviderType.API
def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool:
"""
parse tool bundle to tool
:param tool_bundle: the tool bundle
:return: the tool
"""
return ApiTool(
api_bundle=tool_bundle,
provider_id=self.provider_id,
entity=ToolEntity(
identity=ToolIdentity(
author=tool_bundle.author,
name=tool_bundle.operation_id or "default_tool",
label=I18nObject(
en_US=tool_bundle.operation_id or "default_tool",
zh_Hans=tool_bundle.operation_id or "default_tool",
),
icon=self.entity.identity.icon,
provider=self.provider_id,
),
description=ToolDescription(
human=I18nObject(en_US=tool_bundle.summary or "", zh_Hans=tool_bundle.summary or ""),
llm=tool_bundle.summary or "",
),
parameters=tool_bundle.parameters or [],
),
runtime=ToolRuntime(tenant_id=self.tenant_id),
)
def load_bundled_tools(self, tools: list[ApiToolBundle]):
"""
load bundled tools
:param tools: the bundled tools
:return: the tools
"""
self.tools = [self._parse_tool_bundle(tool) for tool in tools]
return self.tools
def get_tools(self, tenant_id: str) -> list[ApiTool]:
"""
fetch tools from database
:param tenant_id: the tenant id
:return: the tools
"""
if len(self.tools) > 0:
return self.tools
tools: list[ApiTool] = []
# get tenant api providers
db_providers = db.session.scalars(
select(ApiToolProvider).where(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name
)
).all()
for db_provider in db_providers:
    for tool in db_provider.tools:
        tools.append(self._parse_tool_bundle(tool))
self.tools = tools
return tools
def get_tool(self, tool_name: str) -> ApiTool:
"""
get tool by name
:param tool_name: the name of the tool
:return: the tool
"""
if not self.tools:
    self.get_tools(self.tenant_id)
for tool in self.tools:
if tool.entity.identity.name == tool_name:
return tool
raise ValueError(f"tool {tool_name} not found")

View File

@@ -0,0 +1,411 @@
import json
from collections.abc import Generator
from dataclasses import dataclass
from os import getenv
from typing import Any, Union
from urllib.parse import urlencode
import httpx
from core.file.file_manager import download
from core.helper import ssrf_proxy
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
API_TOOL_DEFAULT_TIMEOUT = (
int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")),
int(getenv("API_TOOL_DEFAULT_READ_TIMEOUT", "60")),
)
@dataclass
class ParsedResponse:
"""Represents a parsed HTTP response with type information"""
content: Union[str, dict]
is_json: bool
def to_string(self) -> str:
"""Convert response to string format for credential validation"""
if isinstance(self.content, dict):
return json.dumps(self.content, ensure_ascii=False)
return str(self.content)
class ApiTool(Tool):
"""
Api tool
"""
def __init__(self, entity: ToolEntity, api_bundle: ApiToolBundle, runtime: ToolRuntime, provider_id: str):
super().__init__(entity, runtime)
self.api_bundle = api_bundle
self.provider_id = provider_id
def fork_tool_runtime(self, runtime: ToolRuntime):
"""
fork a new tool with metadata
:return: the new tool
"""
if self.api_bundle is None:
raise ValueError("api_bundle is required")
return self.__class__(
entity=self.entity,
api_bundle=self.api_bundle.model_copy(),
runtime=runtime,
provider_id=self.provider_id,
)
def validate_credentials(
self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
) -> str:
"""
validate the credentials for Api tool
"""
# assemble validate request and request parameters
headers = self.assembling_request(parameters)
if format_only:
return ""
response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters)
# validate response
parsed_response = self.validate_and_parse_response(response)
# For credential validation, always return as string
return parsed_response.to_string()
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.API
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
headers = {}
if self.runtime is None:
raise ToolProviderCredentialValidationError("runtime not initialized")
credentials = self.runtime.credentials or {}
if "auth_type" not in credentials:
raise ToolProviderCredentialValidationError("Missing auth_type")
if credentials["auth_type"] in ("api_key_header", "api_key"): # backward compatibility:
api_key_header = "Authorization"
if "api_key_header" in credentials:
api_key_header = credentials["api_key_header"]
if "api_key_value" not in credentials:
raise ToolProviderCredentialValidationError("Missing api_key_value")
elif not isinstance(credentials["api_key_value"], str):
raise ToolProviderCredentialValidationError("api_key_value must be a string")
if "api_key_header_prefix" in credentials:
api_key_header_prefix = credentials["api_key_header_prefix"]
if api_key_header_prefix == "basic" and credentials["api_key_value"]:
credentials["api_key_value"] = f"Basic {credentials['api_key_value']}"
elif api_key_header_prefix == "bearer" and credentials["api_key_value"]:
credentials["api_key_value"] = f"Bearer {credentials['api_key_value']}"
elif api_key_header_prefix == "custom":
pass
headers[api_key_header] = credentials["api_key_value"]
elif credentials["auth_type"] == "api_key_query":
# For query parameter authentication, we don't add anything to headers
# The query parameter will be added in do_http_request method
pass
needed_parameters = [parameter for parameter in (self.api_bundle.parameters or []) if parameter.required]
for parameter in needed_parameters:
if parameter.required and parameter.name not in parameters:
if parameter.default is not None:
parameters[parameter.name] = parameter.default
else:
raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")
return headers
def validate_and_parse_response(self, response: httpx.Response) -> ParsedResponse:
"""
validate the response and return parsed content with type information
:return: ParsedResponse with content and is_json flag
"""
if isinstance(response, httpx.Response):
if response.status_code >= 400:
raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
if not response.content:
return ParsedResponse(
"Empty response from the tool, please check your parameters and try again.", False
)
# Check content type
content_type = response.headers.get("content-type", "").lower()
is_json_content_type = "application/json" in content_type
# Try to parse as JSON
try:
json_response = response.json()
# If content-type indicates JSON, return as JSON object
if is_json_content_type:
return ParsedResponse(json_response, True)
else:
# If content-type doesn't indicate JSON, treat as text regardless of content
return ParsedResponse(response.text, False)
except Exception:
# Not valid JSON, return as text
return ParsedResponse(response.text, False)
else:
raise ValueError(f"Invalid response type {type(response)}")
@staticmethod
def get_parameter_value(parameter, parameters):
if parameter["name"] in parameters:
return parameters[parameter["name"]]
elif parameter.get("required", False):
raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
else:
return (parameter.get("schema", {}) or {}).get("default", "")
def do_http_request(
self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]
) -> httpx.Response:
"""
do http request depending on api bundle
"""
method = method.lower()
params = {}
path_params = {}
# FIXME: body should be a dict[str, Any] but it changed a lot in this function
body: Any = {}
cookies = {}
files = []
# Add API key to query parameters if auth_type is api_key_query
if self.runtime and self.runtime.credentials:
credentials = self.runtime.credentials
if credentials.get("auth_type") == "api_key_query":
api_key_query_param = credentials.get("api_key_query_param", "key")
api_key_value = credentials.get("api_key_value")
if api_key_value:
params[api_key_query_param] = api_key_value
# check parameters
for parameter in self.api_bundle.openapi.get("parameters", []):
value = self.get_parameter_value(parameter, parameters)
if parameter["in"] == "path":
path_params[parameter["name"]] = value
elif parameter["in"] == "query":
if value != "":
params[parameter["name"]] = value
elif parameter["in"] == "cookie":
cookies[parameter["name"]] = value
elif parameter["in"] == "header":
headers[parameter["name"]] = str(value)
# check if there is a request body and handle it
if "requestBody" in self.api_bundle.openapi and self.api_bundle.openapi["requestBody"] is not None:
# handle json request body
if "content" in self.api_bundle.openapi["requestBody"]:
for content_type in self.api_bundle.openapi["requestBody"]["content"]:
headers["Content-Type"] = content_type
body_schema = self.api_bundle.openapi["requestBody"]["content"][content_type]["schema"]
# handle ref schema
if "$ref" in body_schema:
ref_path = body_schema["$ref"].split("/")
ref_name = ref_path[-1]
if (
"components" in self.api_bundle.openapi
and "schemas" in self.api_bundle.openapi["components"]
):
if ref_name in self.api_bundle.openapi["components"]["schemas"]:
body_schema = self.api_bundle.openapi["components"]["schemas"][ref_name]
required = body_schema.get("required", [])
properties = body_schema.get("properties", {})
for name, property in properties.items():
if name in parameters:
# multiple file upload: if the type is array and the items have format as binary
if property.get("type") == "array" and property.get("items", {}).get("format") == "binary":
# parameters[name] should be a list of file objects.
for f in parameters[name]:
files.append((name, (f.filename, download(f), f.mime_type)))
elif property.get("format") == "binary":
f = parameters[name]
files.append((name, (f.filename, download(f), f.mime_type)))
elif "$ref" in property:
body[name] = parameters[name]
else:
# convert type
body[name] = self._convert_body_property_type(property, parameters[name])
elif name in required:
raise ToolParameterValidationError(
f"Missing required parameter {name} in operation {self.api_bundle.operation_id}"
)
elif "default" in property:
body[name] = property["default"]
else:
# omit optional parameters that weren't provided, instead of setting them to None
pass
break
# replace path parameters
for name, value in path_params.items():
url = url.replace(f"{{{name}}}", f"{value}")
# parse http body data if needed
if "Content-Type" in headers:
if headers["Content-Type"] == "application/json":
body = json.dumps(body)
elif headers["Content-Type"] == "application/x-www-form-urlencoded":
body = urlencode(body)
# any other content type is sent as-is
# if there is a file upload, remove the Content-Type header
# so that httpx can automatically generate the boundary header required for multipart/form-data.
# issue: https://github.com/langgenius/dify/issues/13684
# reference: https://stackoverflow.com/questions/39280438/fetch-missing-boundary-in-multipart-form-data-post
if files:
headers.pop("Content-Type", None)
_METHOD_MAP = {
"get": ssrf_proxy.get,
"head": ssrf_proxy.head,
"post": ssrf_proxy.post,
"put": ssrf_proxy.put,
"delete": ssrf_proxy.delete,
"patch": ssrf_proxy.patch,
}
method_lc = method.lower()
if method_lc not in _METHOD_MAP:
raise ValueError(f"Invalid http method {method}")
response: httpx.Response = _METHOD_MAP[
method_lc
]( # https://discuss.python.org/t/type-inference-for-function-return-types/42926
url,
max_retries=0,
params=params,
headers=headers,
cookies=cookies,
data=body,
files=files,
timeout=API_TOOL_DEFAULT_TIMEOUT,
follow_redirects=True,
)
return response
def _convert_body_property_any_of(
self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
):
if max_recursive <= 0:
raise Exception("Max recursion depth reached")
for option in any_of or []:
try:
if "type" in option:
# Attempt to convert the value based on the type.
if option["type"] == "integer" or option["type"] == "int":
return int(value)
elif option["type"] == "number":
if "." in str(value):
return float(value)
else:
return int(value)
elif option["type"] == "string":
return str(value)
elif option["type"] == "boolean":
if str(value).lower() in {"true", "1"}:
return True
elif str(value).lower() in {"false", "0"}:
return False
else:
continue # Not a boolean, try next option
elif option["type"] == "null" and not value:
return None
else:
continue # Unsupported type, try next option
elif "anyOf" in option and isinstance(option["anyOf"], list):
# Recursive call to handle nested anyOf
return self._convert_body_property_any_of(property, value, option["anyOf"], max_recursive - 1)
except ValueError:
continue # Conversion failed, try next option
# If no option succeeded, you might want to return the value as is or raise an error
return value # or raise ValueError(f"Cannot convert value '{value}' to any specified type in anyOf")
def _convert_body_property_type(self, property: dict[str, Any], value: Any):
try:
if "type" in property:
if property["type"] == "integer" or property["type"] == "int":
return int(value)
elif property["type"] == "number":
# check if it is a float
if "." in str(value):
return float(value)
else:
return int(value)
elif property["type"] == "string":
return str(value)
elif property["type"] == "boolean":
if isinstance(value, str):
    # bool("false") would be True, so map common string spellings explicitly
    return value.lower() in {"true", "1"}
return bool(value)
elif property["type"] == "null":
if value is None:
return None
elif property["type"] == "object" or property["type"] == "array":
if isinstance(value, str):
try:
return json.loads(value)
except ValueError:
return value
elif isinstance(value, dict):
return value
else:
return value
else:
raise ValueError(f"Invalid type {property['type']} for property {property}")
elif "anyOf" in property and isinstance(property["anyOf"], list):
return self._convert_body_property_any_of(property, value, property["anyOf"])
except ValueError:
return value
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
invoke http request
"""
# assemble request
headers = self.assembling_request(tool_parameters)
# do http request
response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters)
# validate response
parsed_response = self.validate_and_parse_response(response)
# assemble invoke message based on response type
if parsed_response.is_json:
if isinstance(parsed_response.content, dict):
yield self.create_json_message(parsed_response.content)
# The yield below must be preserved to keep backward compatibility.
#
# ref: https://github.com/langgenius/dify/pull/23456#issuecomment-3182413088
yield self.create_text_message(response.text)
else:
# Convert to string if needed and create text message
text_response = (
parsed_response.content if isinstance(parsed_response.content, str) else str(parsed_response.content)
)
yield self.create_text_message(text_response)
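A condensed, standalone restatement of the coercion rules implemented by _convert_body_property_type above, shown on made-up schema/value pairs.
import json

def coerce(schema: dict, value):
    t = schema.get("type")
    if t in ("integer", "int"):
        return int(value)
    if t == "number":
        return float(value) if "." in str(value) else int(value)
    if t in ("object", "array") and isinstance(value, str):
        try:
            return json.loads(value)
        except ValueError:
            return value
    return value

assert coerce({"type": "integer"}, "42") == 42
assert coerce({"type": "number"}, "3.14") == 3.14
assert coerce({"type": "object"}, '{"a": 1}') == {"a": 1}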

View File

@@ -0,0 +1,132 @@
from collections.abc import Mapping
from datetime import datetime
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
class ToolApiEntity(BaseModel):
author: str
name: str # identifier
label: I18nObject # label
description: I18nObject
parameters: list[ToolParameter] | None = None
labels: list[str] = Field(default_factory=list)
output_schema: Mapping[str, object] = Field(default_factory=dict)
ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow", "mcp"] | None
class ToolProviderApiEntity(BaseModel):
id: str
author: str
name: str # identifier
description: I18nObject
icon: str | Mapping[str, str]
icon_dark: str | Mapping[str, str] = ""
label: I18nObject # label
type: ToolProviderType
masked_credentials: Mapping[str, object] = Field(default_factory=dict)
original_credentials: Mapping[str, object] = Field(default_factory=dict)
is_team_authorization: bool = False
allow_delete: bool = True
plugin_id: str | None = Field(default="", description="The plugin id of the tool")
plugin_unique_identifier: str | None = Field(default="", description="The unique identifier of the tool")
tools: list[ToolApiEntity] = Field(default_factory=list[ToolApiEntity])
labels: list[str] = Field(default_factory=list)
# MCP
server_url: str | None = Field(default="", description="The server url of the tool")
updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
server_identifier: str | None = Field(default="", description="The server identifier of the MCP tool")
masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool")
original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool")
authentication: MCPAuthentication | None = Field(default=None, description="The OAuth config of the MCP tool")
is_dynamic_registration: bool = Field(default=True, description="Whether the MCP tool is dynamically registered")
configuration: MCPConfiguration | None = Field(
default=None, description="The timeout and sse_read_timeout of the MCP tool"
)
@field_validator("tools", mode="before")
@classmethod
def convert_none_to_empty_list(cls, v):
return v if v is not None else []
def to_dict(self):
# -------------
# overwrite tool parameter types for temp fix
tools = jsonable_encoder(self.tools)
for tool in tools:
if tool.get("parameters"):
for parameter in tool.get("parameters"):
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES:
parameter["type"] = "files"
if parameter.get("input_schema") is None:
parameter.pop("input_schema", None)
# -------------
optional_fields = self.optional_field("server_url", self.server_url)
if self.type == ToolProviderType.MCP:
optional_fields.update(self.optional_field("updated_at", self.updated_at))
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
optional_fields.update(
self.optional_field(
"configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration()
)
)
optional_fields.update(
self.optional_field("authentication", self.authentication.model_dump() if self.authentication else None)
)
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
optional_fields.update(self.optional_field("original_headers", self.original_headers))
return {
"id": self.id,
"author": self.author,
"name": self.name,
"plugin_id": self.plugin_id,
"plugin_unique_identifier": self.plugin_unique_identifier,
"description": self.description.to_dict(),
"icon": self.icon,
"icon_dark": self.icon_dark,
"label": self.label.to_dict(),
"type": self.type.value,
"team_credentials": self.masked_credentials,
"is_team_authorization": self.is_team_authorization,
"allow_delete": self.allow_delete,
"tools": tools,
"labels": self.labels,
**optional_fields,
}
def optional_field(self, key: str, value: Any):
"""Return dict with key-value if value is truthy, empty dict otherwise."""
return {key: value} if value else {}
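# Illustrative sketch: optional_field keeps a key only when its value is truthy, so MCP-only
# fields are simply absent from the payload for other provider types. The standalone mirror
# below (with a hypothetical name) reproduces that rule for demonstration only.
def _optional_field_demo(key, value):
    return {key: value} if value else {}

assert _optional_field_demo("server_url", "https://example.com/mcp") == {"server_url": "https://example.com/mcp"}
assert _optional_field_demo("server_identifier", None) == {}
assert _optional_field_demo("masked_headers", {}) == {}  # empty mappings are dropped as well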
class ToolProviderCredentialApiEntity(BaseModel):
id: str = Field(description="The unique id of the credential")
name: str = Field(description="The name of the credential")
provider: str = Field(description="The provider of the credential")
credential_type: CredentialType = Field(description="The type of the credential")
is_default: bool = Field(
default=False, description="Whether the credential is the default credential for the provider in the workspace"
)
credentials: Mapping[str, object] = Field(description="The credentials of the provider", default_factory=dict)
class ToolProviderCredentialInfoApiEntity(BaseModel):
supported_credential_types: list[CredentialType] = Field(
description="The supported credential types of the provider"
)
is_oauth_custom_client_enabled: bool = Field(
default=False, description="Whether the OAuth custom client is enabled for the provider"
)
credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider")

View File

@@ -0,0 +1,22 @@
from pydantic import BaseModel, Field, model_validator
class I18nObject(BaseModel):
"""
Model class for i18n object.
"""
en_US: str
zh_Hans: str | None = Field(default=None)
pt_BR: str | None = Field(default=None)
ja_JP: str | None = Field(default=None)
@model_validator(mode="after")
def _populate_missing_locales(self):
self.zh_Hans = self.zh_Hans or self.en_US
self.pt_BR = self.pt_BR or self.en_US
self.ja_JP = self.ja_JP or self.en_US
return self
def to_dict(self):
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}

View File

@@ -0,0 +1 @@
TOOL_SELECTOR_MODEL_IDENTITY = "__dify__tool_selector__"

View File

@@ -0,0 +1,27 @@
from pydantic import BaseModel
from core.tools.entities.tool_entities import ToolParameter
class ApiToolBundle(BaseModel):
"""
This class is used to store the schema information of an api based tool.
such as the url, the method, the parameters, etc.
"""
# server_url
server_url: str
# method
method: str
# summary
summary: str | None = None
# operation_id
operation_id: str | None = None
# parameters
parameters: list[ToolParameter] | None = None
# author
author: str
# icon
icon: str | None = None
# openapi operation
openapi: dict
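A minimal sketch of a hand-built bundle for a single GET operation; the URL and operation_id are invented for illustration.
bundle = ApiToolBundle(
    server_url="https://api.example.com/v1/weather",
    method="get",
    summary="Get current weather for a city",
    operation_id="get_weather",
    parameters=[],
    author="example",
    icon=None,
    openapi={"operationId": "get_weather", "parameters": []},
)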

View File

@@ -0,0 +1,492 @@
import base64
import contextlib
from collections.abc import Mapping
from enum import StrEnum, auto
from typing import Any, Union
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.parameters import (
MCPServerParameterType,
PluginParameter,
PluginParameterOption,
PluginParameterType,
as_normal_type,
cast_parameter_value,
init_frontend_parameter,
)
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
class ToolLabelEnum(StrEnum):
SEARCH = "search"
IMAGE = "image"
VIDEOS = "videos"
WEATHER = "weather"
FINANCE = "finance"
DESIGN = "design"
TRAVEL = "travel"
SOCIAL = "social"
NEWS = "news"
MEDICAL = "medical"
PRODUCTIVITY = "productivity"
EDUCATION = "education"
BUSINESS = "business"
ENTERTAINMENT = "entertainment"
UTILITIES = "utilities"
RAG = "rag"
OTHER = "other"
class ToolProviderType(StrEnum):
"""
Enum class for tool provider
"""
PLUGIN = auto()
BUILT_IN = "builtin"
WORKFLOW = auto()
API = auto()
APP = auto()
DATASET_RETRIEVAL = "dataset-retrieval"
MCP = auto()
@classmethod
def value_of(cls, value: str) -> "ToolProviderType":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid mode value {value}")
class ApiProviderSchemaType(StrEnum):
"""
Enum class for api provider schema type.
"""
OPENAPI = auto()
SWAGGER = auto()
OPENAI_PLUGIN = auto()
OPENAI_ACTIONS = auto()
@classmethod
def value_of(cls, value: str) -> "ApiProviderSchemaType":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid mode value {value}")
class ApiProviderAuthType(StrEnum):
"""
Enum class for api provider auth type.
"""
NONE = auto()
API_KEY_HEADER = auto()
API_KEY_QUERY = auto()
@classmethod
def value_of(cls, value: str) -> "ApiProviderAuthType":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
# 'api_key' deprecated in PR #21656
# normalize & tiny alias for backward compatibility
v = (value or "").strip().lower()
if v == "api_key":
v = cls.API_KEY_HEADER
for mode in cls:
if mode.value == v:
return mode
valid = ", ".join(m.value for m in cls)
raise ValueError(f"invalid mode value '{value}', expected one of: {valid}")
class ToolInvokeMessage(BaseModel):
class TextMessage(BaseModel):
text: str
class JsonMessage(BaseModel):
json_object: dict
suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string")
class BlobMessage(BaseModel):
blob: bytes
class BlobChunkMessage(BaseModel):
id: str = Field(..., description="The id of the blob")
sequence: int = Field(..., description="The sequence of the chunk")
total_length: int = Field(..., description="The total length of the blob")
blob: bytes = Field(..., description="The blob data of the chunk")
end: bool = Field(..., description="Whether the chunk is the last chunk")
class FileMessage(BaseModel):
pass
class VariableMessage(BaseModel):
variable_name: str = Field(..., description="The name of the variable")
variable_value: Any = Field(..., description="The value of the variable")
stream: bool = Field(default=False, description="Whether the variable is streamed")
@model_validator(mode="before")
@classmethod
def transform_variable_value(cls, values):
"""
Only basic types and lists are allowed.
"""
value = values.get("variable_value")
if not isinstance(value, dict | list | str | int | float | bool):
raise ValueError("Only basic types and lists are allowed.")
# if stream is true, the value must be a string
if values.get("stream"):
if not isinstance(value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
return values
@field_validator("variable_name", mode="before")
@classmethod
def transform_variable_name(cls, value: str) -> str:
"""
The variable name must be a string.
"""
if value in {"json", "text", "files"}:
raise ValueError(f"The variable name '{value}' is reserved.")
return value
class LogMessage(BaseModel):
class LogStatus(StrEnum):
START = auto()
ERROR = auto()
SUCCESS = auto()
id: str
label: str = Field(..., description="The label of the log")
parent_id: str | None = Field(default=None, description="Leave empty for root log")
error: str | None = Field(default=None, description="The error message")
status: LogStatus = Field(..., description="The status of the log")
data: Mapping[str, Any] = Field(..., description="Detailed log data")
metadata: Mapping[str, Any] = Field(default_factory=dict, description="The metadata of the log")
@field_validator("metadata", mode="before")
@classmethod
def _normalize_metadata(cls, value: Mapping[str, Any] | None) -> Mapping[str, Any]:
return value or {}
class RetrieverResourceMessage(BaseModel):
retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
class MessageType(StrEnum):
TEXT = auto()
IMAGE = auto()
LINK = auto()
BLOB = auto()
JSON = auto()
IMAGE_LINK = auto()
BINARY_LINK = auto()
VARIABLE = auto()
FILE = auto()
LOG = auto()
BLOB_CHUNK = auto()
RETRIEVER_RESOURCES = auto()
type: MessageType = MessageType.TEXT
"""
plain text, image url or link url
"""
message: (
JsonMessage
| TextMessage
| BlobChunkMessage
| BlobMessage
| LogMessage
| FileMessage
| None
| VariableMessage
| RetrieverResourceMessage
)
meta: dict[str, Any] | None = None
@field_validator("message", mode="before")
@classmethod
def decode_blob_message(cls, v):
if isinstance(v, dict) and "blob" in v:
with contextlib.suppress(Exception):
v["blob"] = base64.b64decode(v["blob"])
return v
@field_serializer("message")
def serialize_message(self, v):
if isinstance(v, self.BlobMessage):
return {"blob": base64.b64encode(v.blob).decode("utf-8")}
return v
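# Sketch of the blob round-trip above: raw bytes are base64-encoded on dump and decoded
# again by the "before" validator when a dict with a "blob" key comes back in.
_msg = ToolInvokeMessage(
    type=ToolInvokeMessage.MessageType.BLOB,
    message=ToolInvokeMessage.BlobMessage(blob=b"\x00\x01"),
)
assert _msg.model_dump()["message"] == {"blob": base64.b64encode(b"\x00\x01").decode("utf-8")}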
class ToolInvokeMessageBinary(BaseModel):
mimetype: str = Field(..., description="The mimetype of the binary")
url: str = Field(..., description="The url of the binary")
file_var: dict[str, Any] | None = None
class ToolParameter(PluginParameter):
"""
Overrides type
"""
class ToolParameterType(StrEnum):
"""
removes TOOLS_SELECTOR from PluginParameterType
"""
STRING = PluginParameterType.STRING
NUMBER = PluginParameterType.NUMBER
BOOLEAN = PluginParameterType.BOOLEAN
SELECT = PluginParameterType.SELECT
SECRET_INPUT = PluginParameterType.SECRET_INPUT
FILE = PluginParameterType.FILE
FILES = PluginParameterType.FILES
CHECKBOX = PluginParameterType.CHECKBOX
APP_SELECTOR = PluginParameterType.APP_SELECTOR
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR
ANY = PluginParameterType.ANY
DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT
# MCP object and array type parameters
ARRAY = MCPServerParameterType.ARRAY
OBJECT = MCPServerParameterType.OBJECT
# deprecated, should not use.
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES
def as_normal_type(self):
return as_normal_type(self)
def cast_value(self, value: Any):
return cast_parameter_value(self, value)
class ToolParameterForm(StrEnum):
SCHEMA = auto() # should be set while adding tool
FORM = auto() # should be set before invoking tool
LLM = auto() # will be set by LLM
type: ToolParameterType = Field(..., description="The type of the parameter")
human_description: I18nObject | None = Field(default=None, description="The description presented to the user")
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
llm_description: str | None = None
# MCP object and array type parameters use this field to store the schema
input_schema: dict | None = None
@classmethod
def get_simple_instance(
cls,
name: str,
llm_description: str,
typ: ToolParameterType,
required: bool,
options: list[str] | None = None,
) -> "ToolParameter":
"""
get a simple tool parameter
:param name: the name of the parameter
:param llm_description: the description presented to the LLM
:param typ: the type of the parameter
:param required: if the parameter is required
:param options: the options of the parameter
"""
# convert options to ToolParameterOption
if options:
option_objs = [
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
for option in options
]
else:
option_objs = []
return cls(
name=name,
label=I18nObject(en_US="", zh_Hans=""),
placeholder=None,
human_description=I18nObject(en_US="", zh_Hans=""),
type=typ,
form=cls.ToolParameterForm.LLM,
llm_description=llm_description,
required=required,
options=option_objs,
)
def init_frontend_parameter(self, value: Any):
return init_frontend_parameter(self, self.type, value)
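# Sketch of get_simple_instance above: an LLM-filled select parameter with two made-up options.
_city_param = ToolParameter.get_simple_instance(
    name="city",
    llm_description="Name of the city to query",
    typ=ToolParameter.ToolParameterType.SELECT,
    required=True,
    options=["Berlin", "Tokyo"],
)
assert _city_param.form == ToolParameter.ToolParameterForm.LLM
assert [o.value for o in _city_param.options] == ["Berlin", "Tokyo"]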
class ToolProviderIdentity(BaseModel):
author: str = Field(..., description="The author of the tool")
name: str = Field(..., description="The name of the tool")
description: I18nObject = Field(..., description="The description of the tool")
icon: str = Field(..., description="The icon of the tool")
icon_dark: str | None = Field(default=None, description="The dark icon of the tool")
label: I18nObject = Field(..., description="The label of the tool")
tags: list[ToolLabelEnum] | None = Field(
default=[],
description="The tags of the tool",
)
class ToolIdentity(BaseModel):
author: str = Field(..., description="The author of the tool")
name: str = Field(..., description="The name of the tool")
label: I18nObject = Field(..., description="The label of the tool")
provider: str = Field(..., description="The provider of the tool")
icon: str | None = None
class ToolDescription(BaseModel):
human: I18nObject = Field(..., description="The description presented to the user")
llm: str = Field(..., description="The description presented to the LLM")
class ToolEntity(BaseModel):
identity: ToolIdentity
parameters: list[ToolParameter] = Field(default_factory=list[ToolParameter])
description: ToolDescription | None = None
output_schema: Mapping[str, object] = Field(default_factory=dict)
has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@field_validator("parameters", mode="before")
@classmethod
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
return v or []
@field_validator("output_schema", mode="before")
@classmethod
def _normalize_output_schema(cls, value: Mapping[str, object] | None) -> Mapping[str, object]:
return value or {}
class OAuthSchema(BaseModel):
client_schema: list[ProviderConfig] = Field(
default_factory=list[ProviderConfig], description="The schema of the OAuth client"
)
credentials_schema: list[ProviderConfig] = Field(
default_factory=list[ProviderConfig], description="The schema of the OAuth credentials"
)
class ToolProviderEntity(BaseModel):
identity: ToolProviderIdentity
plugin_id: str | None = None
credentials_schema: list[ProviderConfig] = Field(default_factory=list[ProviderConfig])
oauth_schema: OAuthSchema | None = None
class ToolProviderEntityWithPlugin(ToolProviderEntity):
tools: list[ToolEntity] = Field(default_factory=list[ToolEntity])
class WorkflowToolParameterConfiguration(BaseModel):
"""
Workflow tool configuration
"""
name: str = Field(..., description="The name of the parameter")
description: str = Field(..., description="The description of the parameter")
form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter")
class ToolInvokeMeta(BaseModel):
"""
Tool invoke meta
"""
time_cost: float = Field(..., description="The time cost of the tool invoke")
error: str | None = None
tool_config: dict | None = None
@classmethod
def empty(cls) -> "ToolInvokeMeta":
"""
Get an empty instance of ToolInvokeMeta
"""
return cls(time_cost=0.0, error=None, tool_config={})
@classmethod
def error_instance(cls, error: str) -> "ToolInvokeMeta":
"""
Get an instance of ToolInvokeMeta with error
"""
return cls(time_cost=0.0, error=error, tool_config={})
def to_dict(self):
return {
"time_cost": self.time_cost,
"error": self.error,
"tool_config": self.tool_config,
}
class ToolLabel(BaseModel):
"""
Tool label
"""
name: str = Field(..., description="The name of the tool")
label: I18nObject = Field(..., description="The label of the tool")
icon: str = Field(..., description="The icon of the tool")
class ToolInvokeFrom(StrEnum):
"""
Enum class for tool invoke
"""
WORKFLOW = auto()
AGENT = auto()
PLUGIN = auto()
class ToolSelector(BaseModel):
dify_model_identity: str = TOOL_SELECTOR_MODEL_IDENTITY
class Parameter(BaseModel):
name: str = Field(..., description="The name of the parameter")
type: ToolParameter.ToolParameterType = Field(..., description="The type of the parameter")
required: bool = Field(..., description="Whether the parameter is required")
description: str = Field(..., description="The description of the parameter")
default: Union[int, float, str] | None = None
options: list[PluginParameterOption] | None = None
provider_id: str = Field(..., description="The id of the provider")
credential_id: str | None = Field(default=None, description="The id of the credential")
tool_name: str = Field(..., description="The name of the tool")
tool_description: str = Field(..., description="The description of the tool")
tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
tool_parameters: Mapping[str, Parameter] = Field(..., description="Parameters, type llm")
def to_plugin_parameter(self) -> dict[str, Any]:
return self.model_dump()

View File

@@ -0,0 +1,117 @@
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolLabel, ToolLabelEnum
ICONS = {
ToolLabelEnum.SEARCH: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M7.33398 1.3335C10.646 1.3335 13.334 4.0215 13.334 7.3335C13.334 10.6455 10.646 13.3335 7.33398 13.3335C4.02198 13.3335 1.33398 10.6455 1.33398 7.3335C1.33398 4.0215 4.02198 1.3335 7.33398 1.3335ZM7.33398 12.0002C9.91232 12.0002 12.0007 9.91183 12.0007 7.3335C12.0007 4.75516 9.91232 2.66683 7.33398 2.66683C4.75565 2.66683 2.66732 4.75516 2.66732 7.3335C2.66732 9.91183 4.75565 12.0002 7.33398 12.0002ZM12.9909 12.0476L14.8764 13.9332L13.9337 14.876L12.0481 12.9904L12.9909 12.0476Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.IMAGE: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M13.0514 9.71752L10.4718 7.13792C10.2115 6.87752 9.78932 6.87752 9.52898 7.13792L4.57721 12.0897C3.4097 11.1113 2.66732 9.64232 2.66732 7.99992C2.66732 5.0544 5.05513 2.66659 8.00065 2.66659C10.9462 2.66659 13.334 5.0544 13.334 7.99992C13.334 8.60085 13.2346 9.17852 13.0514 9.71752ZM5.72683 12.8257L10.0004 8.55212L12.4259 10.9777C11.4668 12.4001 9.84152 13.3331 8.00038 13.3331C7.18632 13.3331 6.41628 13.1511 5.72683 12.8257ZM8.00065 14.6666C11.6825 14.6666 14.6673 11.6818 14.6673 7.99992C14.6673 4.31802 11.6825 1.33325 8.00065 1.33325C4.31875 1.33325 1.33398 4.31802 1.33398 7.99992C1.33398 11.6818 4.31875 14.6666 8.00065 14.6666ZM7.33398 6.66658C7.33398 7.40299 6.73705 7.99992 6.00065 7.99992C5.26427 7.99992 4.66732 7.40299 4.66732 6.66658C4.66732 5.9302 5.26427 5.33325 6.00065 5.33325C6.73705 5.33325 7.33398 5.9302 7.33398 6.66658Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.VIDEOS: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M8.00065 13.3333H13.334V14.6666H8.00065C4.31875 14.6666 1.33398 11.6818 1.33398 7.99992C1.33398 4.31802 4.31875 1.33325 8.00065 1.33325C11.6825 1.33325 14.6673 4.31802 14.6673 7.99992C14.6673 9.50072 14.1714 10.8857 13.3345 11.9999H11.5284C12.6356 11.0227 13.334 9.59285 13.334 7.99992C13.334 5.0544 10.9462 2.66659 8.00065 2.66659C5.05513 2.66659 2.66732 5.0544 2.66732 7.99992C2.66732 10.9455 5.05513 13.3333 8.00065 13.3333ZM8.00065 6.66658C7.26425 6.66658 6.66732 6.06963 6.66732 5.33325C6.66732 4.59687 7.26425 3.99992 8.00065 3.99992C8.73705 3.99992 9.33398 4.59687 9.33398 5.33325C9.33398 6.06963 8.73705 6.66658 8.00065 6.66658ZM5.33398 9.33325C4.5976 9.33325 4.00065 8.73632 4.00065 7.99992C4.00065 7.26352 4.5976 6.66658 5.33398 6.66658C6.07036 6.66658 6.66732 7.26352 6.66732 7.99992C6.66732 8.73632 6.07036 9.33325 5.33398 9.33325ZM10.6673 9.33325C9.93092 9.33325 9.33398 8.73632 9.33398 7.99992C9.33398 7.26352 9.93092 6.66658 10.6673 6.66658C11.4037 6.66658 12.0007 7.26352 12.0007 7.99992C12.0007 8.73632 11.4037 9.33325 10.6673 9.33325ZM8.00065 11.9999C7.26425 11.9999 6.66732 11.403 6.66732 10.6666C6.66732 9.93018 7.26425 9.33325 8.00065 9.33325C8.73705 9.33325 9.33398 9.93018 9.33398 10.6666C9.33398 11.403 8.73705 11.9999 8.00065 11.9999Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.WEATHER: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M6.6553 3.37344C7.42088 2.1484 8.78162 1.3335 10.3327 1.3335C12.7259 1.3335 14.666 3.2736 14.666 5.66683C14.666 6.38704 14.4903 7.06623 14.1794 7.66383C14.8894 8.3325 15.3327 9.28123 15.3327 10.3335C15.3327 12.3586 13.6911 14.0002 11.666 14.0002H5.99935C3.05383 14.0002 0.666016 11.6124 0.666016 8.66683C0.666016 5.72131 3.05383 3.3335 5.99935 3.3335C6.22143 3.3335 6.44034 3.34707 6.6553 3.37344ZM8.03628 3.73629C9.37768 4.29108 10.4435 5.37735 10.9711 6.73256C11.1961 6.68943 11.4284 6.66683 11.666 6.66683C12.1561 6.66683 12.6237 6.76296 13.0511 6.93743C13.2317 6.55162 13.3327 6.12102 13.3327 5.66683C13.3327 4.00998 11.9895 2.66683 10.3327 2.66683C9.41115 2.66683 8.58662 3.08236 8.03628 3.73629ZM11.666 12.6668C12.9547 12.6668 13.9993 11.6222 13.9993 10.3335C13.9993 9.04483 12.9547 8.00016 11.666 8.00016C11.013 8.00016 10.4227 8.26836 9.99922 8.70063C9.99928 8.68936 9.99935 8.6781 9.99935 8.66683C9.99935 6.45769 8.20848 4.66683 5.99935 4.66683C3.79021 4.66683 1.99935 6.45769 1.99935 8.66683C1.99935 10.876 3.79021 12.6668 5.99935 12.6668H11.666Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.FINANCE: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M8.00262 14.6685C4.32071 14.6685 1.33594 11.6838 1.33594 8.00184C1.33594 4.31997 4.32071 1.33521 8.00262 1.33521C11.6845 1.33521 14.6693 4.31997 14.6693 8.00184C14.6693 11.6838 11.6845 14.6685 8.00262 14.6685ZM8.00262 13.3352C10.9482 13.3352 13.336 10.9474 13.336 8.00184C13.336 5.05635 10.9482 2.66854 8.00262 2.66854C5.05708 2.66854 2.66927 5.05635 2.66927 8.00184C2.66927 10.9474 5.05708 13.3352 8.00262 13.3352ZM5.66927 9.33517H9.33595C9.52002 9.33517 9.66928 9.18597 9.66928 9.00184C9.66928 8.81777 9.52002 8.66851 9.33595 8.66851H6.66928C5.7488 8.66851 5.0026 7.92237 5.0026 7.00184C5.0026 6.08139 5.7488 5.33521 6.66928 5.33521H7.33595V4.00187H8.66928V5.33521H10.336V6.66851H6.66928C6.48518 6.66851 6.33594 6.81777 6.33594 7.00184C6.33594 7.18597 6.48518 7.33517 6.66928 7.33517H9.33595C10.2564 7.33517 11.0026 8.08137 11.0026 9.00184C11.0026 9.92237 10.2564 10.6685 9.33595 10.6685H8.66928V12.0018H7.33595V10.6685H5.66927V9.33517Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.DESIGN: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M4.70152 9.41416L3.2873 10.8284L5.17292 12.714L12.7154 5.17154L10.8298 3.28592L9.41557 4.70013L10.3584 5.64295L9.41557 6.58575L8.47277 5.64295L7.52997 6.58575L8.47277 7.52856L7.52997 8.47136L6.58713 7.52856L5.64433 8.47136L6.58713 9.41416L5.64433 10.357L4.70152 9.41416ZM11.3012 1.87171L14.1296 4.70013C14.39 4.96049 14.39 5.38259 14.1296 5.64295L5.64433 14.1282C5.38397 14.3886 4.96187 14.3886 4.70152 14.1282L1.87309 11.2998C1.61274 11.0394 1.61274 10.6174 1.87309 10.357L10.3584 1.87171C10.6187 1.61136 11.0408 1.61136 11.3012 1.87171ZM9.41557 12.2423L10.3584 11.2995L11.8534 12.7945H12.7962V11.8517L11.3012 10.3567L12.244 9.41383L14.0011 11.171V13.9999H11.1732L9.41557 12.2423ZM3.75861 6.58533L1.87299 4.69971C1.61265 4.43937 1.61265 4.01725 1.87299 3.75691L3.75861 1.87129C4.01896 1.61094 4.44107 1.61094 4.70142 1.87129L6.58704 3.75691L5.64423 4.69971L4.23002 3.2855L3.28721 4.22831L4.70142 5.64253L3.75861 6.58533Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.TRAVEL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M9.44839 2C9.80198 2 10.1411 2.14047 10.3912 2.39053L13.6101 5.60947C13.8602 5.85953 14.0007 6.19866 14.0007 6.55229V11.3333H15.334V12.6667L9.91652 12.6672C9.62032 13.8171 8.57638 14.6667 7.33398 14.6667C6.0916 14.6667 5.04766 13.8171 4.75146 12.6672L2.00065 12.6667C1.63246 12.6667 1.33398 12.3682 1.33398 12V3.33333C1.33398 2.59695 1.93094 2 2.66732 2H9.44839ZM7.33398 10.6667C6.5976 10.6667 6.00065 11.2636 6.00065 12C6.00065 12.7364 6.5976 13.3333 7.33398 13.3333C8.07038 13.3333 8.66732 12.7364 8.66732 12C8.66732 11.2636 8.07038 10.6667 7.33398 10.6667ZM9.44839 3.33333H2.66732V11.3333L4.75128 11.3335C5.04726 10.1833 6.09136 9.33333 7.33398 9.33333C8.57658 9.33333 9.62072 10.1833 9.91665 11.3335L12.6673 11.3333V6.55229L9.44839 3.33333ZM9.33398 4.66667V8.66667H4.00065V4.66667H9.33398ZM8.00065 6H5.33398V7.33333H8.00065V6Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.SOCIAL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M13.334 7.99992C13.334 5.0544 10.9462 2.66659 8.00065 2.66659C5.05513 2.66659 2.66732 5.0544 2.66732 7.99992C2.66732 10.9455 5.05513 13.3333 8.00065 13.3333C9.09518 13.3333 10.1127 13.0035 10.9594 12.438L11.699 13.5475C10.6408 14.2545 9.36885 14.6666 8.00065 14.6666C4.31875 14.6666 1.33398 11.6818 1.33398 7.99992C1.33398 4.31802 4.31875 1.33325 8.00065 1.33325C11.6825 1.33325 14.6673 4.31802 14.6673 7.99992V8.99992C14.6673 10.2886 13.6227 11.3333 12.334 11.3333C11.5312 11.3333 10.8231 10.9278 10.4032 10.3105C9.79678 10.9409 8.94452 11.3333 8.00065 11.3333C6.1597 11.3333 4.66732 9.84085 4.66732 7.99992C4.66732 6.15897 6.1597 4.66658 8.00065 4.66658C8.75118 4.66658 9.44378 4.91464 10.001 5.33325H11.334V8.99992C11.334 9.55219 11.7817 9.99992 12.334 9.99992C12.8863 9.99992 13.334 9.55219 13.334 8.99992V7.99992ZM8.00065 5.99992C6.89605 5.99992 6.00065 6.89532 6.00065 7.99992C6.00065 9.10452 6.89605 9.99992 8.00065 9.99992C9.10525 9.99992 10.0007 9.10452 10.0007 7.99992C10.0007 6.89532 9.10525 5.99992 8.00065 5.99992Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.NEWS: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M10.6673 13.3335V2.66683H2.66732V12.6668C2.66732 13.035 2.9658 13.3335 3.33398 13.3335H10.6673ZM12.6673 14.6668H3.33398C2.22942 14.6668 1.33398 13.7714 1.33398 12.6668V2.00016C1.33398 1.63198 1.63246 1.3335 2.00065 1.3335H11.334C11.7022 1.3335 12.0007 1.63198 12.0007 2.00016V6.66683H14.6673V12.6668C14.6673 13.7714 13.7719 14.6668 12.6673 14.6668ZM12.0007 8.00016V12.6668C12.0007 13.035 12.2991 13.3335 12.6673 13.3335C13.0355 13.3335 13.334 13.035 13.334 12.6668V8.00016H12.0007ZM4.00065 4.00016H8.00065V8.00016H4.00065V4.00016ZM5.33398 5.3335V6.66683H6.66732V5.3335H5.33398ZM4.00065 8.66683H9.33398V10.0002H4.00065V8.66683ZM4.00065 10.6668H9.33398V12.0002H4.00065V10.6668Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.MEDICAL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M8.79747 1.51186L10.9641 5.26464C11.1482 5.5835 11.0389 5.99122 10.7201 6.17532L9.85373 6.67474L10.5207 7.83001L9.366 8.49668L8.699 7.34141L7.83333 7.84201C7.51447 8.02608 7.10673 7.91681 6.92267 7.59794L5.69747 5.47632C4.32922 5.89145 3.33333 7.16268 3.33333 8.66654C3.33333 9.08348 3.40987 9.48248 3.54965 9.85034C4.06613 9.52254 4.67762 9.33321 5.33333 9.33321C6.45605 9.33321 7.44913 9.88828 8.05313 10.7389L13.1787 7.78014L13.8454 8.93488L8.5932 11.9672C8.64133 12.1927 8.66667 12.4267 8.66667 12.6665C8.66667 12.895 8.64367 13.1181 8.59993 13.3337L14 13.3332V14.6665L2.66703 14.6673C2.2482 14.1101 2 13.4173 2 12.6665C2 11.9951 2.19855 11.3699 2.54014 10.8467C2.19517 10.1964 2 9.45428 2 8.66654C2 6.66968 3.25421 4.96575 5.01785 4.29953L4.75598 3.84519C4.38779 3.20747 4.60629 2.39202 5.24402 2.02382L6.97607 1.02382C7.6138 0.655637 8.42927 0.874138 8.79747 1.51186ZM5.33333 10.6665C4.22877 10.6665 3.33333 11.562 3.33333 12.6665C3.33333 12.9003 3.37343 13.1247 3.44711 13.3331H7.21953C7.29327 13.1247 7.33333 12.9003 7.33333 12.6665C7.33333 11.562 6.4379 10.6665 5.33333 10.6665ZM7.64273 2.17852L5.91068 3.17852L7.744 6.35395L9.47607 5.35395L7.64273 2.17852Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.PRODUCTIVITY: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M6.64807 11.9999H9.35062C9.43862 11.1989 9.84742 10.5376 10.5111 9.81499C10.5858 9.73365 11.0652 9.23752 11.1221 9.16665C11.6872 8.46199 11.9993 7.58992 11.9993 6.66659C11.9993 4.45745 10.2085 2.66659 7.99935 2.66659C5.79021 2.66659 3.99935 4.45745 3.99935 6.66659C3.99935 7.58945 4.31118 8.46105 4.87576 9.16552C4.93271 9.23659 5.41322 9.73405 5.48704 9.81445C6.15112 10.5375 6.56004 11.1989 6.64807 11.9999ZM9.33268 13.3333H6.66602V13.9999H9.33268V13.3333ZM3.83532 9.99939C3.10365 9.08639 2.66602 7.92759 2.66602 6.66659C2.66602 3.72107 5.05383 1.33325 7.99935 1.33325C10.9449 1.33325 13.3327 3.72107 13.3327 6.66659C13.3327 7.92825 12.8945 9.08759 12.1622 10.0009C11.7487 10.5165 10.666 11.3333 10.666 12.3333V13.9999C10.666 14.7363 10.0691 15.3333 9.33268 15.3333H6.66602C5.92964 15.3333 5.33268 14.7363 5.33268 13.9999V12.3333C5.33268 11.3333 4.24907 10.5157 3.83532 9.99939ZM8.66602 6.66979H10.3327L7.33268 10.6698V8.00312H5.66602L8.66602 3.99992V6.66979Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.EDUCATION: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M14 2.66683H4.66667C3.93029 2.66683 3.33333 3.26378 3.33333 4.00016C3.33333 4.73654 3.93029 5.3335 4.66667 5.3335H14V14.0002C14 14.3684 13.7015 14.6668 13.3333 14.6668H4.66667C3.19391 14.6668 2 13.4729 2 12.0002V4.00016C2 2.5274 3.19391 1.3335 4.66667 1.3335H13.3333C13.7015 1.3335 14 1.63198 14 2.00016V2.66683ZM3.33333 12.0002C3.33333 12.7366 3.93029 13.3335 4.66667 13.3335H12.6667V6.66683H4.66667C4.18095 6.66683 3.72557 6.53697 3.33333 6.31008V12.0002ZM13.3333 4.66683H4.66667C4.29848 4.66683 4 4.36835 4 4.00016C4 3.63198 4.29848 3.3335 4.66667 3.3335H13.3333V4.66683Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.BUSINESS: """<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 14 14" fill="none">
<path d="M3.66732 3.33341V1.33341C3.66732 0.965228 3.9658 0.666748 4.33398 0.666748H9.66732C10.0355 0.666748 10.334 0.965228 10.334 1.33341V3.33341H13.0007C13.3689 3.33341 13.6673 3.63189 13.6673 4.00008V13.3334C13.6673 13.7016 13.3689 14.0001 13.0007 14.0001H1.00065C0.632464 14.0001 0.333984 13.7016 0.333984 13.3334V4.00008C0.333984 3.63189 0.632464 3.33341 1.00065 3.33341H3.66732ZM12.334 8.66675H1.66732V12.6667H12.334V8.66675ZM12.334 4.66675H1.66732V7.33341H3.66732V6.00008H5.00065V7.33341H9.00065V6.00008H10.334V7.33341H12.334V4.66675ZM5.00065 2.00008V3.33341H9.00065V2.00008H5.00065Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.ENTERTAINMENT: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M11.3327 2.66675C13.5418 2.66675 15.3327 4.45761 15.3327 6.66675V9.33342C15.3327 11.5425 13.5418 13.3334 11.3327 13.3334H4.66602C2.45688 13.3334 0.666016 11.5425 0.666016 9.33342V6.66675C0.666016 4.45761 2.45688 2.66675 4.66602 2.66675H11.3327ZM11.3327 4.00008H4.66602C3.23788 4.00008 2.07196 5.12273 2.00262 6.53365L1.99935 6.66675V9.33342C1.99935 10.7615 3.122 11.9275 4.53292 11.9968L4.66602 12.0001H11.3327C12.7608 12.0001 13.9267 10.8774 13.9961 9.46648L13.9993 9.33342V6.66675C13.9993 5.23861 12.8767 4.07269 11.4657 4.00335L11.3327 4.00008ZM6.66602 6.00008V7.33342H7.99935V8.66675H6.66535L6.66602 10.0001H5.33268L5.33202 8.66675H3.99935V7.33342H5.33268V6.00008H6.66602ZM11.9993 8.66675V10.0001H10.666V8.66675H11.9993ZM10.666 6.00008V7.33342H9.33268V6.00008H10.666Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.UTILITIES: """<svg xmlns="http://www.w3.org/2000/svg" width="13" height="15" viewBox="0 0 13 15" fill="none">
<path d="M12.3346 0.333252C12.7028 0.333252 13.0013 0.631732 13.0013 0.999919V4.33325C13.0013 4.70144 12.7028 4.99992 12.3346 4.99992H9.0013V13.6666C9.0013 14.0348 8.70284 14.3333 8.33463 14.3333H5.66797C5.29978 14.3333 5.0013 14.0348 5.0013 13.6666V4.99992H1.33464C0.966449 4.99992 0.667969 4.70144 0.667969 4.33325V2.74527C0.667969 2.49276 0.810635 2.26192 1.0365 2.14899L4.66797 0.333252H12.3346ZM9.0013 1.66659H4.98273L2.0013 3.1573V3.66659H6.33464V12.9999H7.66797V3.66659H9.0013V1.66659ZM11.668 1.66659H10.3346V3.66659H11.668V1.66659Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.OTHER: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M8.00052 0.666748L4.00065 7.33342H12.0007L8.00052 0.666748ZM8.00052 3.25828L9.64572 6.00008H6.35553L8.00052 3.25828ZM4.50065 13.3334C3.48813 13.3334 2.66732 12.5126 2.66732 11.5001C2.66732 10.4875 3.48813 9.66675 4.50065 9.66675C5.51317 9.66675 6.33398 10.4875 6.33398 11.5001C6.33398 12.5126 5.51317 13.3334 4.50065 13.3334ZM4.50065 14.6667C6.24955 14.6667 7.66732 13.249 7.66732 11.5001C7.66732 9.75115 6.24955 8.33342 4.50065 8.33342C2.75175 8.33342 1.33398 9.75115 1.33398 11.5001C1.33398 13.249 2.75175 14.6667 4.50065 14.6667ZM10.0007 10.3334V13.0001H12.6673V10.3334H10.0007ZM8.66732 14.3334V9.00008H14.0007V14.3334H8.66732Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.RAG: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M8.00065 1.3335H9.33398V2.66683H8.00065V1.3335ZM5.33398 1.3335H6.66732V2.66683H5.33398V1.3335ZM3.99935 2.66683C3.99935 2.29864 4.29783 2.00016 4.66602 2.00016H12.3327C12.7009 2.00016 13.0007 2.29864 13.0007 2.66683V13.3335C13.0007 13.7017 12.7009 14.0002 12.3327 14.0002H4.66602C4.29783 14.0002 3.99935 13.7017 3.99935 13.3335V2.66683ZM4.66602 12.6668C4.29783 12.6668 3.99935 12.3683 3.99935 12.0002V10.6668H5.33398V12.0002C5.33398 12.3683 5.0355 12.6668 4.66602 12.6668ZM5.33398 8.66683H6.66732V10.0002H5.33398V8.66683ZM5.33398 6.66683H6.66732V8.00016H5.33398V6.66683ZM3.99935 4.66683H6.66602V6.00016H3.99935V4.66683ZM6.66602 1.3335H12.3327V2.66683H6.66602V1.3335Z" fill="#344054"/>
</svg>""", # noqa: E501
}
default_tool_label_dict = {
ToolLabelEnum.SEARCH: ToolLabel(
name="search", label=I18nObject(en_US="Search", zh_Hans="搜索"), icon=ICONS[ToolLabelEnum.SEARCH]
),
ToolLabelEnum.IMAGE: ToolLabel(
name="image", label=I18nObject(en_US="Image", zh_Hans="图片"), icon=ICONS[ToolLabelEnum.IMAGE]
),
ToolLabelEnum.VIDEOS: ToolLabel(
name="videos", label=I18nObject(en_US="Videos", zh_Hans="视频"), icon=ICONS[ToolLabelEnum.VIDEOS]
),
ToolLabelEnum.WEATHER: ToolLabel(
name="weather", label=I18nObject(en_US="Weather", zh_Hans="天气"), icon=ICONS[ToolLabelEnum.WEATHER]
),
ToolLabelEnum.FINANCE: ToolLabel(
name="finance", label=I18nObject(en_US="Finance", zh_Hans="金融"), icon=ICONS[ToolLabelEnum.FINANCE]
),
ToolLabelEnum.DESIGN: ToolLabel(
name="design", label=I18nObject(en_US="Design", zh_Hans="设计"), icon=ICONS[ToolLabelEnum.DESIGN]
),
ToolLabelEnum.TRAVEL: ToolLabel(
name="travel", label=I18nObject(en_US="Travel", zh_Hans="旅行"), icon=ICONS[ToolLabelEnum.TRAVEL]
),
ToolLabelEnum.SOCIAL: ToolLabel(
name="social", label=I18nObject(en_US="Social", zh_Hans="社交"), icon=ICONS[ToolLabelEnum.SOCIAL]
),
ToolLabelEnum.NEWS: ToolLabel(
name="news", label=I18nObject(en_US="News", zh_Hans="新闻"), icon=ICONS[ToolLabelEnum.NEWS]
),
ToolLabelEnum.MEDICAL: ToolLabel(
name="medical", label=I18nObject(en_US="Medical", zh_Hans="医疗"), icon=ICONS[ToolLabelEnum.MEDICAL]
),
ToolLabelEnum.PRODUCTIVITY: ToolLabel(
name="productivity",
label=I18nObject(en_US="Productivity", zh_Hans="生产力"),
icon=ICONS[ToolLabelEnum.PRODUCTIVITY],
),
ToolLabelEnum.EDUCATION: ToolLabel(
name="education", label=I18nObject(en_US="Education", zh_Hans="教育"), icon=ICONS[ToolLabelEnum.EDUCATION]
),
ToolLabelEnum.BUSINESS: ToolLabel(
name="business", label=I18nObject(en_US="Business", zh_Hans="商业"), icon=ICONS[ToolLabelEnum.BUSINESS]
),
ToolLabelEnum.ENTERTAINMENT: ToolLabel(
name="entertainment",
label=I18nObject(en_US="Entertainment", zh_Hans="娱乐"),
icon=ICONS[ToolLabelEnum.ENTERTAINMENT],
),
ToolLabelEnum.UTILITIES: ToolLabel(
name="utilities", label=I18nObject(en_US="Utilities", zh_Hans="工具"), icon=ICONS[ToolLabelEnum.UTILITIES]
),
ToolLabelEnum.OTHER: ToolLabel(
name="other", label=I18nObject(en_US="Other", zh_Hans="其他"), icon=ICONS[ToolLabelEnum.OTHER]
),
ToolLabelEnum.RAG: ToolLabel(
name="rag", label=I18nObject(en_US="RAG", zh_Hans="RAG"), icon=ICONS[ToolLabelEnum.RAG]
),
}
default_tool_labels = list(default_tool_label_dict.values())
default_tool_label_name_list = [label.name for label in default_tool_labels]

View File

@@ -0,0 +1,41 @@
from core.tools.entities.tool_entities import ToolInvokeMeta
class ToolProviderNotFoundError(ValueError):
pass
class ToolNotFoundError(ValueError):
pass
class ToolParameterValidationError(ValueError):
pass
class ToolProviderCredentialValidationError(ValueError):
pass
class ToolNotSupportedError(ValueError):
pass
class ToolInvokeError(ValueError):
pass
class ToolApiSchemaError(ValueError):
pass
class ToolCredentialPolicyViolationError(ValueError):
pass
class ToolEngineInvokeError(Exception):
meta: ToolInvokeMeta
def __init__(self, meta: ToolInvokeMeta, **kwargs):
self.meta = meta
super().__init__(**kwargs)
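A short sketch of how the engine-level error is meant to be consumed; the error text is made up.
try:
    raise ToolEngineInvokeError(ToolInvokeMeta.error_instance("upstream API returned 500"))
except ToolEngineInvokeError as e:
    print(e.meta.to_dict())  # {'time_cost': 0.0, 'error': 'upstream API returned 500', 'tool_config': {}}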

View File

@@ -0,0 +1,156 @@
from typing import Any, Self
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.types import Tool as RemoteMCPTool
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ToolDescription,
ToolEntity,
ToolIdentity,
ToolProviderEntityWithPlugin,
ToolProviderIdentity,
ToolProviderType,
)
from core.tools.mcp_tool.tool import MCPTool
from models.tools import MCPToolProvider
from services.tools.tools_transform_service import ToolTransformService
class MCPToolProviderController(ToolProviderController):
def __init__(
self,
entity: ToolProviderEntityWithPlugin,
provider_id: str,
tenant_id: str,
server_url: str,
headers: dict[str, str] | None = None,
timeout: float | None = None,
sse_read_timeout: float | None = None,
):
super().__init__(entity)
self.entity: ToolProviderEntityWithPlugin = entity
self.tenant_id = tenant_id
self.provider_id = provider_id
self.server_url = server_url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
@property
def provider_type(self) -> ToolProviderType:
"""
returns the type of the provider
:return: type of the provider
"""
return ToolProviderType.MCP
@classmethod
def from_db(cls, db_provider: MCPToolProvider) -> Self:
"""
from db provider
"""
# Convert to entity first
provider_entity = db_provider.to_entity()
return cls.from_entity(provider_entity)
@classmethod
def from_entity(cls, entity: MCPProviderEntity) -> Self:
"""
create a MCPToolProviderController from a MCPProviderEntity
"""
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in entity.tools]
tools = [
ToolEntity(
identity=ToolIdentity(
author="Anonymous", # Tool level author is not stored
name=remote_mcp_tool.name,
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
provider=entity.provider_id,
icon=entity.icon if isinstance(entity.icon, str) else "",
),
parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
description=ToolDescription(
human=I18nObject(
en_US=remote_mcp_tool.description or "", zh_Hans=remote_mcp_tool.description or ""
),
llm=remote_mcp_tool.description or "",
),
output_schema=remote_mcp_tool.outputSchema or {},
has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0,
)
for remote_mcp_tool in remote_mcp_tools
]
if not entity.icon:
raise ValueError("Database provider icon is required")
return cls(
entity=ToolProviderEntityWithPlugin(
identity=ToolProviderIdentity(
author="Anonymous", # Provider level author is not stored in entity
name=entity.name,
label=I18nObject(en_US=entity.name, zh_Hans=entity.name),
description=I18nObject(en_US="", zh_Hans=""),
icon=entity.icon if isinstance(entity.icon, str) else "",
),
plugin_id=None,
credentials_schema=[],
tools=tools,
),
provider_id=entity.provider_id,
tenant_id=entity.tenant_id,
server_url=entity.server_url,
headers=entity.headers,
timeout=entity.timeout,
sse_read_timeout=entity.sse_read_timeout,
)
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
"""
validate the credentials of the provider
"""
pass
def get_tool(self, tool_name: str) -> MCPTool:
"""
return tool with given name
"""
tool_entity = next(
(tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name), None
)
if not tool_entity:
raise ValueError(f"Tool with name {tool_name} not found")
return MCPTool(
entity=tool_entity,
runtime=ToolRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
server_url=self.server_url,
provider_id=self.provider_id,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
def get_tools(self) -> list[MCPTool]:
"""
get all tools
"""
return [
MCPTool(
entity=tool_entity,
runtime=ToolRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
server_url=self.server_url,
provider_id=self.provider_id,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
for tool_entity in self.entity.tools
]
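
A hedged usage sketch of the controller above; `db_provider` and the tool name are placeholders, and `get_tool` raises ValueError when the name is not in the provider's tool list.

# db_provider: a models.tools.MCPToolProvider row loaded elsewhere (placeholder).
controller = MCPToolProviderController.from_db(db_provider)
assert controller.provider_type == ToolProviderType.MCP
mcp_tool = controller.get_tool("list_issues")  # placeholder tool name
runnable = mcp_tool.fork_tool_runtime(ToolRuntime(tenant_id=controller.tenant_id))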

View File

@@ -0,0 +1,174 @@
import base64
import json
import logging
from collections.abc import Generator
from typing import Any
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
from core.tools.errors import ToolInvokeError
logger = logging.getLogger(__name__)
class MCPTool(Tool):
def __init__(
self,
entity: ToolEntity,
runtime: ToolRuntime,
tenant_id: str,
icon: str,
server_url: str,
provider_id: str,
headers: dict[str, str] | None = None,
timeout: float | None = None,
sse_read_timeout: float | None = None,
):
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.server_url = server_url
self.provider_id = provider_id
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.MCP
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
result = self.invoke_remote_mcp_tool(tool_parameters)
# handle dify tool output
for content in result.content:
if isinstance(content, TextContent):
yield from self._process_text_content(content)
elif isinstance(content, ImageContent):
yield self._process_image_content(content)
elif isinstance(content, AudioContent):
yield self._process_audio_content(content)
else:
logger.warning("Unsupported content type=%s", type(content))
# handle MCP structured output
if self.entity.output_schema and result.structuredContent:
for k, v in result.structuredContent.items():
yield self.create_variable_message(k, v)
def _process_text_content(self, content: TextContent) -> Generator[ToolInvokeMessage, None, None]:
"""Process text content and yield appropriate messages."""
# Check if content looks like JSON before attempting to parse
text = content.text.strip()
if text and text[0] in ("{", "[") and text[-1] in ("}", "]"):
try:
content_json = json.loads(text)
yield from self._process_json_content(content_json)
return
except json.JSONDecodeError:
pass
# If not JSON or parsing failed, treat as plain text
yield self.create_text_message(content.text)
def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]:
"""Process JSON content based on its type."""
if isinstance(content_json, dict):
yield self.create_json_message(content_json)
elif isinstance(content_json, list):
yield from self._process_json_list(content_json)
else:
# For primitive types (str, int, bool, etc.), convert to string
yield self.create_text_message(str(content_json))
def _process_json_list(self, json_list: list) -> Generator[ToolInvokeMessage, None, None]:
"""Process a list of JSON items."""
if any(not isinstance(item, dict) for item in json_list):
# If the list contains any non-dict item, treat the entire list as a text message.
yield self.create_text_message(str(json_list))
return
# Otherwise, process each dictionary as a separate JSON message.
for item in json_list:
yield self.create_json_message(item)
def _process_image_content(self, content: ImageContent) -> ToolInvokeMessage:
"""Process image content and return a blob message."""
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage:
"""Process audio content and return a blob message."""
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
return MCPTool(
entity=self.entity,
runtime=runtime,
tenant_id=self.tenant_id,
icon=self.icon,
server_url=self.server_url,
provider_id=self.provider_id,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:
"""
        in MCP tool invocation, parameters whose value is None or an empty string are dropped
"""
return {
key: value
for key, value in parameter.items()
if value is not None and not (isinstance(value, str) and value.strip() == "")
}
def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any]) -> CallToolResult:
headers = self.headers.copy() if self.headers else {}
tool_parameters = self._handle_none_parameter(tool_parameters)
from sqlalchemy.orm import Session
from extensions.ext_database import db
from services.tools.mcp_tools_manage_service import MCPToolManageService
# Step 1: Load provider entity and credentials in a short-lived session
# This minimizes database connection hold time
with Session(db.engine, expire_on_commit=False) as session:
mcp_service = MCPToolManageService(session=session)
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
# Decrypt and prepare all credentials before closing session
server_url = provider_entity.decrypt_server_url()
headers = provider_entity.decrypt_headers()
# Try to get existing token and add to headers
if not headers:
tokens = provider_entity.retrieve_tokens()
if tokens and tokens.access_token:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
# Step 2: Session is now closed, perform network operations without holding database connection
# MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed
try:
with MCPClientWithAuthRetry(
server_url=server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
provider_entity=provider_entity,
) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e:
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e

View File

@@ -0,0 +1,79 @@
from typing import Any
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin, ToolProviderType
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.plugin_tool.tool import PluginTool
class PluginToolProviderController(BuiltinToolProviderController):
entity: ToolProviderEntityWithPlugin
tenant_id: str
plugin_id: str
plugin_unique_identifier: str
def __init__(
self, entity: ToolProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
):
self.entity = entity
self.tenant_id = tenant_id
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@property
def provider_type(self) -> ToolProviderType:
"""
returns the type of the provider
:return: type of the provider
"""
return ToolProviderType.PLUGIN
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
"""
validate the credentials of the provider
"""
manager = PluginToolManager()
if not manager.validate_provider_credentials(
tenant_id=self.tenant_id,
user_id=user_id,
provider=self.entity.identity.name,
credentials=credentials,
):
raise ToolProviderCredentialValidationError("Invalid credentials")
def get_tool(self, tool_name: str) -> PluginTool: # type: ignore
"""
return tool with given name
"""
tool_entity = next(
(tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name), None
)
if not tool_entity:
raise ValueError(f"Tool with name {tool_name} not found")
return PluginTool(
entity=tool_entity,
runtime=ToolRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)
def get_tools(self) -> list[PluginTool]: # type: ignore
"""
get all tools
"""
return [
PluginTool(
entity=tool_entity,
runtime=ToolRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)
for tool_entity in self.entity.tools
]

View File

@@ -0,0 +1,85 @@
from collections.abc import Generator
from typing import Any
from core.plugin.impl.tool import PluginToolManager
from core.plugin.utils.converter import convert_parameters_to_plugin_format
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
class PluginTool(Tool):
def __init__(
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str
):
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier
self.runtime_parameters: list[ToolParameter] | None = None
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.PLUGIN
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
manager = PluginToolManager()
tool_parameters = convert_parameters_to_plugin_format(tool_parameters)
yield from manager.invoke(
tenant_id=self.tenant_id,
user_id=user_id,
tool_provider=self.entity.identity.provider,
tool_name=self.entity.identity.name,
credentials=self.runtime.credentials,
credential_type=self.runtime.credential_type,
tool_parameters=tool_parameters,
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
)
def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool":
return PluginTool(
entity=self.entity,
runtime=runtime,
tenant_id=self.tenant_id,
icon=self.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)
def get_runtime_parameters(
self,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> list[ToolParameter]:
"""
get the runtime parameters
"""
if not self.entity.has_runtime_parameters:
return self.entity.parameters
if self.runtime_parameters is not None:
return self.runtime_parameters
manager = PluginToolManager()
self.runtime_parameters = manager.get_runtime_parameters(
tenant_id=self.tenant_id,
user_id="",
provider=self.entity.identity.provider,
tool=self.entity.identity.name,
credentials=self.runtime.credentials,
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
)
return self.runtime_parameters
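
A standalone sketch of the caching behaviour in `get_runtime_parameters`: the remote lookup through PluginToolManager happens once, after which the memoized list is returned.

class RuntimeParameterCache:
    # Minimal stand-in for the tool's self.runtime_parameters memoization.
    def __init__(self, loader):
        self._loader = loader
        self._cached = None

    def get(self):
        if self._cached is not None:
            return self._cached
        self._cached = self._loader()
        return self._cached

calls: list[int] = []
cache = RuntimeParameterCache(lambda: calls.append(1) or ["query", "limit"])
cache.get()
cache.get()
print(len(calls))  # 1 -- the loader ran only on the first call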

View File

@@ -0,0 +1,42 @@
import base64
import hashlib
import hmac
import os
import time
from configs import dify_config
def sign_tool_file(tool_file_id: str, extension: str) -> str:
"""
sign file to get a temporary url for plugin access
"""
# Use internal URL for plugin/tool file access in Docker environments
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
"""
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
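
A standalone round-trip sketch of the signing scheme above; the secret key and the 300-second window are stand-ins for dify_config.SECRET_KEY and FILES_ACCESS_TIMEOUT.

import base64
import hashlib
import hmac
import os
import time

secret_key = b"example-secret"  # stand-in for dify_config.SECRET_KEY
tool_file_id = "abc123"

timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
sign = base64.urlsafe_b64encode(hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()).decode()

# Verification recomputes the same MAC and then enforces the expiry window.
recalculated = base64.urlsafe_b64encode(hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()).decode()
assert sign == recalculated
assert int(time.time()) - int(timestamp) <= 300  # stand-in for FILES_ACCESS_TIMEOUT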

View File

@@ -0,0 +1,365 @@
import contextlib
import json
from collections.abc import Generator, Iterable
from copy import deepcopy
from datetime import UTC, datetime
from mimetypes import guess_type
from typing import Any, Union, cast
from yarl import URL
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import FileType
from core.file.models import FileTransferMethod
from core.ops.ops_trace_manager import TraceQueueManager
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolInvokeMessage,
ToolInvokeMessageBinary,
ToolInvokeMeta,
ToolParameter,
)
from core.tools.errors import (
ToolEngineInvokeError,
ToolInvokeError,
ToolNotFoundError,
ToolNotSupportedError,
ToolParameterValidationError,
ToolProviderCredentialValidationError,
ToolProviderNotFoundError,
)
from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import Message, MessageFile
class ToolEngine:
"""
Tool runtime engine take care of the tool executions.
"""
@staticmethod
def agent_invoke(
tool: Tool,
tool_parameters: Union[str, dict],
user_id: str,
tenant_id: str,
message: Message,
invoke_from: InvokeFrom,
agent_tool_callback: DifyAgentCallbackHandler,
trace_manager: TraceQueueManager | None = None,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> tuple[str, list[str], ToolInvokeMeta]:
"""
Agent invokes the tool with the given arguments.
"""
        # check if the arguments were passed as a string
if isinstance(tool_parameters, str):
# check if this tool has only one parameter
parameters = [
parameter
for parameter in tool.get_runtime_parameters()
if parameter.form == ToolParameter.ToolParameterForm.LLM
]
if parameters and len(parameters) == 1:
tool_parameters = {parameters[0].name: tool_parameters}
else:
with contextlib.suppress(Exception):
tool_parameters = json.loads(tool_parameters)
if not isinstance(tool_parameters, dict):
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
try:
# hit the callback handler
agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
messages = ToolEngine._invoke(tool, tool_parameters, user_id, conversation_id, app_id, message_id)
invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
def message_callback(
invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]
):
for message in messages:
if isinstance(message, ToolInvokeMeta):
invocation_meta_dict["meta"] = message
else:
yield message
messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=message_callback(invocation_meta_dict, messages),
user_id=user_id,
tenant_id=tenant_id,
conversation_id=message.conversation_id,
)
message_list = list(messages)
# extract binary data from tool invoke message
binary_files = ToolEngine._extract_tool_response_binary_and_text(message_list)
# create message file
message_files = ToolEngine._create_message_files(
tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id
)
plain_text = ToolEngine._convert_tool_response_to_str(message_list)
meta = invocation_meta_dict["meta"]
# hit the callback handler
agent_tool_callback.on_tool_end(
tool_name=tool.entity.identity.name,
tool_inputs=tool_parameters,
tool_outputs=plain_text,
message_id=message.id,
trace_manager=trace_manager,
)
# transform tool invoke message to get LLM friendly message
return plain_text, message_files, meta
except ToolProviderCredentialValidationError as e:
error_response = "Please check your tool provider credentials"
agent_tool_callback.on_tool_error(e)
except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e:
error_response = f"there is not a tool named {tool.entity.identity.name}"
agent_tool_callback.on_tool_error(e)
except ToolParameterValidationError as e:
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
agent_tool_callback.on_tool_error(e)
except ToolInvokeError as e:
error_response = f"tool invoke error: {e}"
agent_tool_callback.on_tool_error(e)
except ToolEngineInvokeError as e:
meta = e.meta
error_response = f"tool invoke error: {meta.error}"
agent_tool_callback.on_tool_error(e)
return error_response, [], meta
        except Exception as e:
            error_response = f"unknown error: {e}"
            agent_tool_callback.on_tool_error(e)
            return error_response, [], ToolInvokeMeta.error_instance(error_response)

        # The handlers above only set error_response, so return it here.
        return error_response, [], ToolInvokeMeta.error_instance(error_response)
@staticmethod
def generic_invoke(
tool: Tool,
tool_parameters: dict[str, Any],
user_id: str,
workflow_tool_callback: DifyWorkflowCallbackHandler,
workflow_call_depth: int,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
Workflow invokes the tool with the given arguments.
"""
try:
# hit the callback handler
workflow_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
if isinstance(tool, WorkflowTool):
tool.workflow_call_depth = workflow_call_depth + 1
if tool.runtime and tool.runtime.runtime_parameters:
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
response = tool.invoke(
user_id=user_id,
tool_parameters=tool_parameters,
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
)
# hit the callback handler
response = workflow_tool_callback.on_tool_execution(
tool_name=tool.entity.identity.name,
tool_inputs=tool_parameters,
tool_outputs=response,
)
return response
except Exception as e:
workflow_tool_callback.on_tool_error(e)
raise e
@staticmethod
def _invoke(
tool: Tool,
tool_parameters: dict,
user_id: str,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]:
"""
Invoke the tool with the given arguments.
"""
started_at = datetime.now(UTC)
meta = ToolInvokeMeta(
time_cost=0.0,
error=None,
tool_config={
"tool_name": tool.entity.identity.name,
"tool_provider": tool.entity.identity.provider,
"tool_provider_type": tool.tool_provider_type().value,
"tool_parameters": deepcopy(tool.runtime.runtime_parameters),
"tool_icon": tool.entity.identity.icon,
},
)
try:
yield from tool.invoke(user_id, tool_parameters, conversation_id, app_id, message_id)
except Exception as e:
meta.error = str(e)
raise ToolEngineInvokeError(meta)
finally:
ended_at = datetime.now(UTC)
meta.time_cost = (ended_at - started_at).total_seconds()
yield meta
@staticmethod
def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
"""
Handle tool response
"""
parts: list[str] = []
json_parts: list[str] = []
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.TEXT:
parts.append(cast(ToolInvokeMessage.TextMessage, response.message).text)
elif response.type == ToolInvokeMessage.MessageType.LINK:
parts.append(
f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}."
+ " please tell user to check it."
)
elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
parts.append(
"image has been created and sent to user already, "
+ "you do not need to create it, just tell the user to check it now."
)
elif response.type == ToolInvokeMessage.MessageType.JSON:
json_message = cast(ToolInvokeMessage.JsonMessage, response.message)
if json_message.suppress_output:
continue
json_parts.append(
json.dumps(
safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object),
ensure_ascii=False,
)
)
else:
parts.append(str(response.message))
# Add JSON parts, avoiding duplicates from text parts.
if json_parts:
existing_parts = set(parts)
parts.extend(p for p in json_parts if p not in existing_parts)
return "".join(parts)
@staticmethod
def _extract_tool_response_binary_and_text(
tool_response: list[ToolInvokeMessage],
) -> Generator[ToolInvokeMessageBinary, None, None]:
"""
Extract tool response binary
"""
for response in tool_response:
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
mimetype = None
if not response.meta:
raise ValueError("missing meta data")
if response.meta.get("mime_type"):
mimetype = response.meta.get("mime_type")
else:
with contextlib.suppress(Exception):
url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text)
extension = url.suffix
guess_type_result, _ = guess_type(f"a{extension}")
if guess_type_result:
mimetype = guess_type_result
if not mimetype:
mimetype = "image/jpeg"
yield ToolInvokeMessageBinary(
mimetype=response.meta.get("mime_type", mimetype),
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
)
elif response.type == ToolInvokeMessage.MessageType.BLOB:
if not response.meta:
raise ValueError("missing meta data")
yield ToolInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "application/octet-stream"),
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
)
elif response.type == ToolInvokeMessage.MessageType.LINK:
# check if there is a mime type in meta
if response.meta and "mime_type" in response.meta:
yield ToolInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "application/octet-stream")
if response.meta
else "application/octet-stream",
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
)
@staticmethod
def _create_message_files(
tool_messages: Iterable[ToolInvokeMessageBinary],
agent_message: Message,
invoke_from: InvokeFrom,
user_id: str,
) -> list[str]:
"""
Create message file
:return: message file ids
"""
result = []
for message in tool_messages:
if "image" in message.mimetype:
file_type = FileType.IMAGE
elif "video" in message.mimetype:
file_type = FileType.VIDEO
elif "audio" in message.mimetype:
file_type = FileType.AUDIO
elif "text" in message.mimetype or "pdf" in message.mimetype:
file_type = FileType.DOCUMENT
else:
file_type = FileType.CUSTOM
# extract tool file id from url
tool_file_id = message.url.split("/")[-1].split(".")[0]
message_file = MessageFile(
message_id=agent_message.id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
belongs_to="assistant",
url=message.url,
upload_file_id=tool_file_id,
created_by_role=(
CreatorUserRole.ACCOUNT
if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else CreatorUserRole.END_USER
),
created_by=user_id,
)
db.session.add(message_file)
db.session.commit()
db.session.refresh(message_file)
result.append(message_file.id)
db.session.close()
return result
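
A standalone sketch of the bookkeeping pattern in `_invoke`: stream every item the tool yields, then emit one trailing record with the elapsed time (a float here stands in for ToolInvokeMeta).

from collections.abc import Generator
from datetime import UTC, datetime

def timed(items: Generator[str, None, None]) -> Generator[str | float, None, None]:
    started_at = datetime.now(UTC)
    try:
        yield from items
    finally:
        # The trailing value carries the total cost, mirroring meta.time_cost.
        yield (datetime.now(UTC) - started_at).total_seconds()

for item in timed(iter(["chunk-1", "chunk-2"])):
    print(item)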

View File

@@ -0,0 +1,253 @@
import base64
import hashlib
import hmac
import logging
import os
import time
from collections.abc import Generator
from mimetypes import guess_extension, guess_type
from typing import Union
from uuid import uuid4
import httpx
from sqlalchemy.orm import Session
from configs import dify_config
from core.helper import ssrf_proxy
from extensions.ext_database import db as global_db
from extensions.ext_storage import storage
from models.model import MessageFile
from models.tools import ToolFile
logger = logging.getLogger(__name__)
from sqlalchemy.engine import Engine
class ToolFileManager:
_engine: Engine
def __init__(self, engine: Engine | None = None):
if engine is None:
engine = global_db.engine
self._engine = engine
@staticmethod
def sign_file(tool_file_id: str, extension: str) -> str:
"""
sign file to get a temporary url for plugin access
"""
# Use internal URL for plugin/tool file access in Docker environments
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@staticmethod
def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
"""
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
def create_file_by_raw(
self,
*,
user_id: str,
tenant_id: str,
conversation_id: str | None,
file_binary: bytes,
mimetype: str,
filename: str | None = None,
) -> ToolFile:
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
unique_filename = f"{unique_name}{extension}"
        # default to the generated unique filename
present_filename = unique_filename
if filename is not None:
has_extension = len(filename.split(".")) > 1
# Add extension flexibly
present_filename = filename if has_extension else f"{filename}{extension}"
filepath = f"tools/{tenant_id}/{unique_filename}"
storage.save(filepath, file_binary)
with Session(self._engine, expire_on_commit=False) as session:
tool_file = ToolFile(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_key=filepath,
mimetype=mimetype,
name=present_filename,
size=len(file_binary),
original_url=None,
)
session.add(tool_file)
session.commit()
session.refresh(tool_file)
return tool_file
def create_file_by_url(
self,
user_id: str,
tenant_id: str,
file_url: str,
conversation_id: str | None = None,
) -> ToolFile:
# try to download image
try:
response = ssrf_proxy.get(file_url)
response.raise_for_status()
blob = response.content
except httpx.TimeoutException:
raise ValueError(f"timeout when downloading file from {file_url}")
mimetype = (
guess_type(file_url)[0]
or response.headers.get("Content-Type", "").split(";")[0].strip()
or "application/octet-stream"
)
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
filename = f"{unique_name}{extension}"
filepath = f"tools/{tenant_id}/{filename}"
storage.save(filepath, blob)
with Session(self._engine, expire_on_commit=False) as session:
tool_file = ToolFile(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_key=filepath,
mimetype=mimetype,
original_url=file_url,
name=filename,
size=len(blob),
)
session.add(tool_file)
session.commit()
return tool_file
def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]:
"""
get file binary
:param id: the id of the file
:return: the binary of the file, mime type
"""
with Session(self._engine, expire_on_commit=False) as session:
tool_file: ToolFile | None = (
session.query(ToolFile)
.where(
ToolFile.id == id,
)
.first()
)
if not tool_file:
return None
blob = storage.load_once(tool_file.file_key)
return blob, tool_file.mimetype
def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str], None]:
"""
get file binary
:param id: the id of the file
:return: the binary of the file, mime type
"""
with Session(self._engine, expire_on_commit=False) as session:
message_file: MessageFile | None = (
session.query(MessageFile)
.where(
MessageFile.id == id,
)
.first()
)
# Check if message_file is not None
if message_file is not None:
# get tool file id
if message_file.url is not None:
tool_file_id = message_file.url.split("/")[-1]
# trim extension
tool_file_id = tool_file_id.split(".")[0]
else:
tool_file_id = None
else:
tool_file_id = None
tool_file: ToolFile | None = (
session.query(ToolFile)
.where(
ToolFile.id == tool_file_id,
)
.first()
)
if not tool_file:
return None
blob = storage.load_once(tool_file.file_key)
return blob, tool_file.mimetype
def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]:
"""
get file binary
:param tool_file_id: the id of the tool file
:return: the binary of the file, mime type
"""
with Session(self._engine, expire_on_commit=False) as session:
tool_file: ToolFile | None = (
session.query(ToolFile)
.where(
ToolFile.id == tool_file_id,
)
.first()
)
if not tool_file:
return None, None
stream = storage.load_stream(tool_file.file_key)
return stream, tool_file
# init tool_file_parser
from core.file.tool_file_parser import set_tool_file_manager_factory
def _factory() -> ToolFileManager:
return ToolFileManager()
set_tool_file_manager_factory(_factory)
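
A hedged usage sketch of the manager above; it assumes a configured Dify database and storage backend, and the user/tenant IDs and payload are placeholders.

manager = ToolFileManager()  # falls back to the global engine
tool_file = manager.create_file_by_raw(
    user_id="user-1",
    tenant_id="tenant-1",
    conversation_id=None,
    file_binary=b"%PDF-1.7 ...",
    mimetype="application/pdf",
    filename="report",  # extension is appended automatically -> "report.pdf"
)
preview_url = ToolFileManager.sign_file(tool_file.id, ".pdf")
print(preview_url)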

View File

@@ -0,0 +1,97 @@
from sqlalchemy import select
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.entities.values import default_tool_label_name_list
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from extensions.ext_database import db
from models.tools import ToolLabelBinding
class ToolLabelManager:
@classmethod
def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]:
"""
Filter tool labels
"""
tool_labels = [label for label in tool_labels if label in default_tool_label_name_list]
return list(set(tool_labels))
@classmethod
def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
"""
Update tool labels
"""
labels = cls.filter_tool_labels(labels)
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
provider_id = controller.provider_id
else:
raise ValueError("Unsupported tool type")
# delete old labels
db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete()
# insert new labels
for label in labels:
db.session.add(
ToolLabelBinding(
tool_id=provider_id,
tool_type=controller.provider_type.value,
label_name=label,
)
)
db.session.commit()
@classmethod
def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
"""
Get tool labels
"""
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
provider_id = controller.provider_id
elif isinstance(controller, BuiltinToolProviderController):
return controller.tool_labels
else:
raise ValueError("Unsupported tool type")
stmt = select(ToolLabelBinding.label_name).where(
ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type.value,
)
labels = db.session.scalars(stmt).all()
return list(labels)
@classmethod
def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
"""
Get tools labels
:param tool_providers: list of tool providers
:return: dict of tool labels
:key: tool id
:value: list of tool labels
"""
if not tool_providers:
return {}
for controller in tool_providers:
if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
raise ValueError("Unsupported tool type")
provider_ids = []
for controller in tool_providers:
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
provider_ids.append(controller.provider_id)
labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
for label in labels:
tool_labels[label.tool_id].append(label.label_name)
return tool_labels
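
A small sketch of the label filter above; unknown names are dropped and duplicates collapse, so the result order is not guaranteed.

print(ToolLabelManager.filter_tool_labels(["search", "search", "not-a-label", "news"]))
# e.g. ["news", "search"] -- only names from default_tool_label_name_list survive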

File diff suppressed because it is too large

View File

View File

@@ -0,0 +1,158 @@
import contextlib
from copy import deepcopy
from typing import Any
from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolParameter,
ToolProviderType,
)
class ToolParameterConfigurationManager:
"""
Tool parameter configuration manager
"""
tenant_id: str
tool_runtime: Tool
provider_name: str
provider_type: ToolProviderType
identity_id: str
def __init__(
self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
):
self.tenant_id = tenant_id
self.tool_runtime = tool_runtime
self.provider_name = provider_name
self.provider_type = provider_type
self.identity_id = identity_id
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
deep copy parameters
"""
return deepcopy(parameters)
def _merge_parameters(self) -> list[ToolParameter]:
"""
merge parameters
"""
# get tool parameters
tool_parameters = self.tool_runtime.entity.parameters or []
# get tool runtime parameters
runtime_parameters = self.tool_runtime.get_runtime_parameters()
# override parameters
current_parameters = tool_parameters.copy()
for runtime_parameter in runtime_parameters:
found = False
for index, parameter in enumerate(current_parameters):
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
current_parameters[index] = runtime_parameter
found = True
break
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)
return current_parameters
def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
mask tool parameters
return a deep copy of parameters with masked values
"""
parameters = self._deep_copy(parameters)
# override parameters
current_parameters = self._merge_parameters()
for parameter in current_parameters:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
if len(parameters[parameter.name]) > 6:
parameters[parameter.name] = (
parameters[parameter.name][:2]
+ "*" * (len(parameters[parameter.name]) - 4)
+ parameters[parameter.name][-2:]
)
else:
parameters[parameter.name] = "*" * len(parameters[parameter.name])
return parameters
def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
encrypt tool parameters with tenant id
return a deep copy of parameters with encrypted values
"""
# override parameters
current_parameters = self._merge_parameters()
parameters = self._deep_copy(parameters)
for parameter in current_parameters:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
parameters[parameter.name] = encrypted
return parameters
def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
decrypt tool parameters with tenant id
return a deep copy of parameters with decrypted values
"""
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f"{self.provider_type.value}.{self.provider_name}",
tool_name=self.tool_runtime.entity.identity.name,
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id,
)
cached_parameters = cache.get()
if cached_parameters:
return cached_parameters
# override parameters
current_parameters = self._merge_parameters()
has_secret_input = False
for parameter in current_parameters:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
has_secret_input = True
with contextlib.suppress(Exception):
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
if has_secret_input:
cache.set(parameters)
return parameters
def delete_tool_parameters_cache(self):
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f"{self.provider_type.value}.{self.provider_name}",
tool_name=self.tool_runtime.entity.identity.name,
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id,
)
cache.delete()
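
A standalone sketch of the masking rule applied to SECRET_INPUT form parameters in `mask_tool_parameters`.

def mask_secret(value: str) -> str:
    # Values longer than six characters keep their first and last two characters.
    if len(value) > 6:
        return value[:2] + "*" * (len(value) - 4) + value[-2:]
    return "*" * len(value)

print(mask_secret("sk-abcdef123456"))  # "sk" + "*" * 11 + "56"
print(mask_secret("short"))            # "*****"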

View File

@@ -0,0 +1,200 @@
import threading
from typing import Any
from flask import Flask, current_app
from pydantic import BaseModel, Field
from sqlalchemy import select
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}
class DatasetMultiRetrieverToolInput(BaseModel):
query: str = Field(..., description="dataset multi retriever and rerank")
class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
"""Tool for querying multi dataset."""
name: str = "dataset_"
args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
description: str = "dataset multi retriever and rerank. "
dataset_ids: list[str]
reranking_provider_name: str
reranking_model_name: str
@classmethod
def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
return cls(
name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs
)
def _run(self, query: str) -> str:
threads = []
all_documents: list[RagDocument] = []
for dataset_id in self.dataset_ids:
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset_id,
"query": query,
"all_documents": all_documents,
"hit_callbacks": self.hit_callbacks,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
# do rerank for searched documents
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id,
provider=self.reranking_provider_name,
model_type=ModelType.RERANK,
model=self.reranking_model_name,
)
rerank_runner = RerankModelRunner(rerank_model_instance)
all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k)
for hit_callback in self.hit_callbacks:
hit_callback.on_tool_end(all_documents)
document_score_list = {}
for item in all_documents:
if item.metadata and item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
)
segments = db.session.scalars(document_segment_stmt).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments:
if segment.answer:
document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
else:
document_context_list.append(segment.get_sign_content())
if self.return_resource:
context_list: list[RetrievalSourceMetadata] = []
resource_number = 1
for segment in sorted_segments:
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document_stmt = select(Document).where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
document = db.session.scalar(document_stmt)
if dataset and document:
source = RetrievalSourceMetadata(
position=resource_number,
dataset_id=dataset.id,
dataset_name=dataset.name,
document_id=document.id,
document_name=document.name,
data_source_type=document.data_source_type,
segment_id=segment.id,
retriever_from=self.retriever_from,
score=document_score_list.get(segment.index_node_id),
doc_metadata=document.doc_metadata,
)
if self.retriever_from == "dev":
source.hit_count = segment.hit_count
source.word_count = segment.word_count
source.segment_position = segment.position
source.index_node_hash = segment.index_node_hash
if segment.answer:
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source.content = segment.content
context_list.append(source)
resource_number += 1
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list))
return ""
def _retriever(
self,
flask_app: Flask,
dataset_id: str,
query: str,
all_documents: list,
hit_callbacks: list[DatasetIndexToolCallbackHandler],
):
with flask_app.app_context():
stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
return []
for hit_callback in hit_callbacks:
hit_callback.on_query(query, dataset.id)
            # get the retrieval model; if it is not set, use the default
retrieval_model = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 4,
)
if documents:
all_documents.extend(documents)
else:
if self.top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 4,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
)
all_documents.extend(documents)
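
A standalone sketch of the fan-out used in `_run`: one thread per dataset appends into a shared list, and the caller joins them all before reranking.

import threading

def fan_out(dataset_ids: list[str], retrieve) -> list[str]:
    results: list[str] = []
    threads = []
    for dataset_id in dataset_ids:
        thread = threading.Thread(target=lambda d=dataset_id: results.extend(retrieve(d)))
        threads.append(thread)
        thread.start()
    for thread in threads:
        thread.join()
    return results

print(fan_out(["ds-a", "ds-b"], lambda d: [f"{d}-doc-1", f"{d}-doc-2"]))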

View File

@@ -0,0 +1,31 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel, ConfigDict
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
class DatasetRetrieverBaseTool(BaseModel, ABC):
"""Tool for querying a Dataset."""
name: str = "dataset"
description: str = "use this to retrieve a dataset. "
tenant_id: str
top_k: int = 4
score_threshold: float | None = None
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
return_resource: bool
retriever_from: str
model_config = ConfigDict(arbitrary_types_allowed=True)
def run(self, query: str) -> str:
"""Use the tool."""
return self._run(query)
@abstractmethod
def _run(self, query: str) -> str:
"""Use the tool.
Add run_manager: Optional[CallbackManagerForToolRun] = None
        to child implementations to enable tracing.
"""

View File

@@ -0,0 +1,234 @@
from typing import Any, cast
from pydantic import BaseModel, Field
from sqlalchemy import select
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"reranking_mode": "reranking_model",
"top_k": 2,
"score_threshold_enabled": False,
}
class DatasetRetrieverToolInput(BaseModel):
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
class DatasetRetrieverTool(DatasetRetrieverBaseTool):
"""Tool for querying a Dataset."""
name: str = "dataset"
args_schema: type[BaseModel] = DatasetRetrieverToolInput
description: str = "use this to retrieve a dataset. "
dataset_id: str
user_id: str | None = None
retrieve_config: DatasetRetrieveConfigEntity
inputs: dict
@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
description = dataset.description
if not description:
description = "useful for when you want to answer queries about the " + dataset.name
description = description.replace("\n", "").replace("\r", "")
return cls(
name=f"dataset_{dataset.id.replace('-', '_')}",
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
description=description,
**kwargs,
)
def _run(self, query: str) -> str:
dataset_stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id)
dataset = db.session.scalar(dataset_stmt)
if not dataset:
return ""
for hit_callback in self.hit_callbacks:
hit_callback.on_query(query, dataset.id)
dataset_retrieval = DatasetRetrieval()
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
[dataset.id],
query,
self.tenant_id,
self.user_id or "unknown",
cast(str, self.retrieve_config.metadata_filtering_mode),
cast(ModelConfig, self.retrieve_config.metadata_model_config),
self.retrieve_config.metadata_filtering_conditions,
self.inputs,
)
if metadata_filter_document_ids:
document_ids_filter = metadata_filter_document_ids.get(dataset.id, [])
else:
document_ids_filter = None
if dataset.provider == "external":
results: list[RetrievalDocument] = []
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
query=query,
external_retrieval_parameters=dataset.retrieval_model,
metadata_condition=metadata_condition,
)
for external_document in external_documents:
document = RetrievalDocument(
page_content=external_document.get("content"),
metadata=external_document.get("metadata"),
provider="external",
)
if document.metadata is not None:
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset.id
document.metadata["dataset_name"] = dataset.name
results.append(document)
# deal with external documents
context_list: list[RetrievalSourceMetadata] = []
for position, item in enumerate(results, start=1):
if item.metadata is not None:
source = RetrievalSourceMetadata(
position=position,
dataset_id=item.metadata.get("dataset_id"),
dataset_name=item.metadata.get("dataset_name"),
document_id=item.metadata.get("document_id") or item.metadata.get("title"),
document_name=item.metadata.get("title"),
data_source_type="external",
retriever_from=self.retriever_from,
score=item.metadata.get("score"),
title=item.metadata.get("title"),
content=item.page_content,
)
context_list.append(source)
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join([item.page_content for item in results]))
else:
if metadata_condition and not document_ids_filter:
return ""
            # get the retrieval model; if it is not set, use the default
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
retrieval_resource_list: list[RetrievalSourceMetadata] = []
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
dataset_id=dataset.id,
query=query,
top_k=self.top_k,
document_ids_filter=document_ids_filter,
)
return str("\n".join([document.page_content for document in documents]))
else:
if self.top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
dataset_id=dataset.id,
query=query,
top_k=self.top_k,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model")
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights"),
document_ids_filter=document_ids_filter,
)
else:
documents = []
for hit_callback in self.hit_callbacks:
hit_callback.on_tool_end(documents)
document_score_list = {}
if dataset.indexing_technique != "economy":
for item in documents:
if item.metadata is not None and item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list: list[DocumentContext] = []
records = RetrievalService.format_retrieval_documents(documents)
if records:
for record in records:
segment = record.segment
if segment.answer:
document_context_list.append(
DocumentContext(
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
score=record.score,
)
)
else:
document_context_list.append(
DocumentContext(
content=segment.get_sign_content(),
score=record.score,
)
)
if self.return_resource:
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
document = db.session.scalar(dataset_document_stmt)
if dataset and document:
source = RetrievalSourceMetadata(
dataset_id=dataset.id,
dataset_name=dataset.name,
document_id=document.id,
document_name=document.name,
data_source_type=document.data_source_type,
segment_id=segment.id,
retriever_from=self.retriever_from,
score=record.score or 0.0,
doc_metadata=document.doc_metadata,
)
if self.retriever_from == "dev":
source.hit_count = segment.hit_count
source.word_count = segment.word_count
source.segment_position = segment.position
source.index_node_hash = segment.index_node_hash
if segment.answer:
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source.content = segment.content
retrieval_resource_list.append(source)
if self.return_resource and retrieval_resource_list:
retrieval_resource_list = sorted(
retrieval_resource_list,
key=lambda x: x.score or 0.0,
reverse=True,
)
for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore
item.position = position # type: ignore
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(retrieval_resource_list)
if document_context_list:
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
return str("\n".join([document_context.content for document_context in document_context_list]))
return ""

View File

@@ -0,0 +1,136 @@
from collections.abc import Generator
from typing import Any
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ToolDescription,
ToolEntity,
ToolIdentity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
class DatasetRetrieverTool(Tool):
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool):
super().__init__(entity, runtime)
self.retrieval_tool = retrieval_tool
@staticmethod
def get_dataset_tools(
tenant_id: str,
dataset_ids: list[str],
retrieve_config: DatasetRetrieveConfigEntity | None,
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler,
user_id: str,
inputs: dict,
) -> list["DatasetRetrieverTool"]:
"""
        get dataset tools
"""
# check if retrieve_config is valid
if dataset_ids is None or len(dataset_ids) == 0:
return []
if retrieve_config is None:
return []
feature = DatasetRetrieval()
# save original retrieve strategy, and set retrieve strategy to SINGLE
# Agent only support SINGLE mode
original_retriever_mode = retrieve_config.retrieve_strategy
retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
retrieval_tools = feature.to_dataset_retriever_tool(
tenant_id=tenant_id,
dataset_ids=dataset_ids,
retrieve_config=retrieve_config,
return_resource=return_resource,
invoke_from=invoke_from,
hit_callback=hit_callback,
user_id=user_id,
inputs=inputs,
)
if retrieval_tools is None or len(retrieval_tools) == 0:
return []
# restore retrieve strategy
retrieve_config.retrieve_strategy = original_retriever_mode
# convert retrieval tools to Tools
tools = []
for retrieval_tool in retrieval_tools:
tool = DatasetRetrieverTool(
retrieval_tool=retrieval_tool,
entity=ToolEntity(
identity=ToolIdentity(
provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="")
),
parameters=[],
description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description),
),
runtime=ToolRuntime(tenant_id=tenant_id),
)
tools.append(tool)
return tools
def get_runtime_parameters(
self,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> list[ToolParameter]:
return [
ToolParameter(
name="query",
label=I18nObject(en_US="", zh_Hans=""),
human_description=I18nObject(en_US="", zh_Hans=""),
type=ToolParameter.ToolParameterType.STRING,
form=ToolParameter.ToolParameterForm.LLM,
llm_description="Query for the dataset to be used to retrieve the dataset.",
required=True,
default="",
placeholder=I18nObject(en_US="", zh_Hans=""),
),
]
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.DATASET_RETRIEVAL
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
invoke dataset retriever tool
"""
query = tool_parameters.get("query")
if not query:
yield self.create_text_message(text="please input query")
else:
# invoke dataset retriever tool
result = self.retrieval_tool.run(query=query)
yield self.create_text_message(text=result)
def validate_credentials(
self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
) -> str | None:
"""
validate the credentials for dataset retriever tool
"""
pass
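# A minimal usage sketch (not part of the original module): how an agent runner might build
# and invoke these tools. The tenant/dataset/user ids, retrieve_config and hit_callback below
# are placeholders and assume an already-configured Dify application context.
def _example_dataset_tool_invocation(retrieve_config, hit_callback):  # hypothetical helper
    tools = DatasetRetrieverTool.get_dataset_tools(
        tenant_id="<tenant-id>",
        dataset_ids=["<dataset-id>"],
        retrieve_config=retrieve_config,
        return_resource=True,
        invoke_from=InvokeFrom.DEBUGGER,
        hit_callback=hit_callback,
        user_id="<user-id>",
        inputs={},
    )
    for tool in tools:
        for message in tool.invoke(user_id="<user-id>", tool_parameters={"query": "refund policy"}):
            print(message)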

View File

@@ -0,0 +1,32 @@
# Import generic components from provider_encryption module
from core.helper.provider_encryption import (
ProviderConfigCache,
ProviderConfigEncrypter,
create_provider_encrypter,
)
# Re-export for backward compatibility
__all__ = [
"ProviderConfigCache",
"ProviderConfigEncrypter",
"create_provider_encrypter",
"create_tool_provider_encrypter",
]
# Tool-specific imports
from core.helper.provider_cache import SingletonProviderCredentialsCache
from core.tools.__base.tool_provider import ToolProviderController
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
cache = SingletonProviderCredentialsCache(
tenant_id=tenant_id,
provider_type=controller.provider_type.value,
provider_identity=controller.entity.identity.name,
)
encrypt = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
provider_config_cache=cache,
)
return encrypt, cache
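# A minimal sketch (not part of the original module), assuming a concrete ToolProviderController
# instance is at hand; the credential key is a placeholder and the encrypt/decrypt/delete calls
# follow the generic provider_encryption helpers re-exported above.
def _example_credential_roundtrip(controller: ToolProviderController) -> None:  # hypothetical helper
    encrypter, cache = create_tool_provider_encrypter(tenant_id="<tenant-id>", controller=controller)
    stored = encrypter.encrypt({"api_key": "sk-placeholder"})  # encrypt before persisting
    restored = encrypter.decrypt(stored)                       # decrypt when building a runtime
    cache.delete()                                             # drop cached credentials after an update
    assert restored["api_key"] == "sk-placeholder"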

View File

@@ -0,0 +1,168 @@
import logging
from collections.abc import Generator
from datetime import date, datetime
from decimal import Decimal
from mimetypes import guess_extension
from uuid import UUID
import numpy as np
import pytz
from core.file import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_file_manager import ToolFileManager
from libs.login import current_user
from models import Account
logger = logging.getLogger(__name__)
def safe_json_value(v):
if isinstance(v, datetime):
tz_name = "UTC"
if isinstance(current_user, Account) and current_user.timezone is not None:
tz_name = current_user.timezone
return v.astimezone(pytz.timezone(tz_name)).isoformat()
elif isinstance(v, date):
return v.isoformat()
elif isinstance(v, UUID):
return str(v)
elif isinstance(v, Decimal):
return float(v)
elif isinstance(v, bytes):
try:
return v.decode("utf-8")
except UnicodeDecodeError:
return v.hex()
elif isinstance(v, memoryview):
return v.tobytes().hex()
elif isinstance(v, np.ndarray):
return v.tolist()
elif isinstance(v, dict):
return safe_json_dict(v)
elif isinstance(v, list | tuple | set):
return [safe_json_value(i) for i in v]
else:
return v
def safe_json_dict(d: dict):
if not isinstance(d, dict):
raise TypeError("safe_json_dict() expects a dictionary (dict) as input")
return {k: safe_json_value(v) for k, v in d.items()}
class ToolFileMessageTransformer:
@classmethod
def transform_tool_invoke_messages(
cls,
messages: Generator[ToolInvokeMessage, None, None],
user_id: str,
tenant_id: str,
conversation_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
Transform tool message and handle file download
"""
for message in messages:
if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
yield message
elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance(
message.message, ToolInvokeMessage.TextMessage
):
# try to download image
try:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
tool_file_manager = ToolFileManager()
tool_file = tool_file_manager.create_file_by_url(
user_id=user_id,
tenant_id=tenant_id,
file_url=message.message.text,
conversation_id=conversation_id,
)
url = f"/files/tools/{tool_file.id}{guess_extension(tool_file.mimetype) or '.png'}"
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=ToolInvokeMessage.TextMessage(text=url),
meta=message.meta.copy() if message.meta is not None else {},
)
except Exception as e:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT,
message=ToolInvokeMessage.TextMessage(
text=f"Failed to download image: {message.message.text}: {e}"
),
meta=message.meta.copy() if message.meta is not None else {},
)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get mime type and save blob to storage
meta = message.meta or {}
mimetype = meta.get("mime_type", "application/octet-stream")
# get filename from meta
filename = meta.get("filename", None)
# if message is str, encode it to bytes
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
raise ValueError("unexpected message type")
assert isinstance(message.message.blob, bytes)
tool_file_manager = ToolFileManager()
tool_file = tool_file_manager.create_file_by_raw(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_binary=message.message.blob,
mimetype=mimetype,
filename=filename,
)
url = cls.get_tool_file_url(tool_file_id=tool_file.id, extension=guess_extension(tool_file.mimetype))
# check if file is image
if "image" in mimetype:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=ToolInvokeMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BINARY_LINK,
message=ToolInvokeMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
elif message.type == ToolInvokeMessage.MessageType.FILE:
meta = message.meta or {}
file = meta.get("file", None)
if isinstance(file, File):
if file.transfer_method == FileTransferMethod.TOOL_FILE:
assert file.related_id is not None
url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
if file.type == FileType.IMAGE:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=ToolInvokeMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK,
message=ToolInvokeMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield message
elif message.type == ToolInvokeMessage.MessageType.JSON:
if isinstance(message.message, ToolInvokeMessage.JsonMessage):
message.message.json_object = safe_json_value(message.message.json_object)
yield message
else:
yield message
@classmethod
def get_tool_file_url(cls, tool_file_id: str, extension: str | None) -> str:
return f"/files/tools/{tool_file_id}{extension or '.bin'}"

View File

@@ -0,0 +1,167 @@
"""
Some tools, such as WebScraperTool and WikipediaSearchTool, need to invoke models.
Therefore, a model manager is needed to list/invoke/validate models.
"""
import json
from decimal import Decimal
from typing import cast
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from models.tools import ToolModelInvoke
class InvokeModelError(Exception):
pass
class ModelInvocationUtils:
@staticmethod
def get_max_llm_context_tokens(
tenant_id: str,
) -> int:
"""
get max llm context tokens of the model
"""
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
)
if not model_instance:
raise InvokeModelError("Model not found")
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if not schema:
raise InvokeModelError("No model schema found")
max_tokens: int | None = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
if max_tokens is None:
return 2048
return max_tokens
@staticmethod
def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
"""
calculate tokens from prompt messages and model parameters
"""
# get model instance
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM)
if not model_instance:
raise InvokeModelError("Model not found")
# get tokens
tokens = model_instance.get_llm_num_tokens(prompt_messages)
return tokens
@staticmethod
def invoke(
user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
invoke model with parameters in user's own context
:param user_id: user id
:param tenant_id: tenant id, the tenant id of the creator of the tool
:param tool_type: tool type
:param tool_name: tool name
:param prompt_messages: prompt messages
        :return: LLMResult
"""
# get model manager
model_manager = ModelManager()
# get model instance
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
)
# get prompt tokens
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
model_parameters = {
"temperature": 0.8,
"top_p": 0.8,
}
# create tool model invoke
tool_model_invoke = ToolModelInvoke(
user_id=user_id,
tenant_id=tenant_id,
provider=model_instance.provider,
tool_type=tool_type,
tool_name=tool_name,
model_parameters=json.dumps(model_parameters),
prompt_messages=json.dumps(jsonable_encoder(prompt_messages)),
model_response="",
prompt_tokens=prompt_tokens,
answer_tokens=0,
answer_unit_price=Decimal(),
answer_price_unit=Decimal(),
provider_response_latency=0,
total_price=Decimal(),
currency="USD",
)
db.session.add(tool_model_invoke)
db.session.commit()
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
)
except InvokeRateLimitError as e:
raise InvokeModelError(f"Invoke rate limit error: {e}")
except InvokeBadRequestError as e:
raise InvokeModelError(f"Invoke bad request error: {e}")
except InvokeConnectionError as e:
raise InvokeModelError(f"Invoke connection error: {e}")
except InvokeAuthorizationError as e:
raise InvokeModelError("Invoke authorization error")
except InvokeServerUnavailableError as e:
raise InvokeModelError(f"Invoke server unavailable error: {e}")
except Exception as e:
raise InvokeModelError(f"Invoke error: {e}")
# update tool model invoke
tool_model_invoke.model_response = str(response.message.content)
if response.usage:
tool_model_invoke.answer_tokens = response.usage.completion_tokens
tool_model_invoke.answer_unit_price = response.usage.completion_unit_price
tool_model_invoke.answer_price_unit = response.usage.completion_price_unit
tool_model_invoke.provider_response_latency = response.usage.latency
tool_model_invoke.total_price = response.usage.total_price
tool_model_invoke.currency = response.usage.currency
db.session.commit()
return response
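# A hedged sketch (not part of the original module) of how a tool might call this helper;
# it assumes a default LLM is configured for the tenant, the ids are placeholders, and
# UserPromptMessage is available alongside PromptMessage in message_entities.
def _example_tool_llm_call(tenant_id: str, user_id: str) -> str:  # hypothetical helper
    from core.model_runtime.entities.message_entities import UserPromptMessage

    prompts = [UserPromptMessage(content="Summarize: Dify is an LLM app platform.")]
    if ModelInvocationUtils.calculate_tokens(tenant_id, prompts) > ModelInvocationUtils.get_max_llm_context_tokens(tenant_id):
        raise InvokeModelError("Prompt exceeds the model context window")
    result = ModelInvocationUtils.invoke(
        user_id=user_id,
        tenant_id=tenant_id,
        tool_type="builtin",
        tool_name="example_tool",
        prompt_messages=prompts,
    )
    return str(result.message.content)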

View File

@@ -0,0 +1,453 @@
import re
from json import dumps as json_dumps
from json import loads as json_loads
from json.decoder import JSONDecodeError
from typing import Any
import httpx
from flask import request
from yaml import YAMLError, safe_load
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_to_tool_bundle(
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
# set description to extra_info
extra_info["description"] = openapi["info"].get("description", "")
if len(openapi["servers"]) == 0:
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
server_url = openapi["servers"][0]["url"]
request_env = request.headers.get("X-Request-Env")
if request_env:
            matched_servers = [server["url"] for server in openapi["servers"] if server.get("env") == request_env]
server_url = matched_servers[0] if matched_servers else server_url
# list all interfaces
interfaces = []
for path, path_item in openapi["paths"].items():
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
for method in methods:
if method in path_item:
interfaces.append(
{
"path": path,
"method": method,
"operation": path_item[method],
}
)
# get all parameters
bundles = []
for interface in interfaces:
# convert parameters
parameters = []
if "parameters" in interface["operation"]:
for i, parameter in enumerate(interface["operation"]["parameters"]):
if "$ref" in parameter:
root = openapi
reference = parameter["$ref"].split("/")[1:]
for ref in reference:
root = root[ref]
interface["operation"]["parameters"][i] = root
for parameter in interface["operation"]["parameters"]:
# Handle complex type defaults that are not supported by PluginParameter
default_value = None
if "schema" in parameter and "default" in parameter["schema"]:
default_value = ApiBasedToolSchemaParser._sanitize_default_value(parameter["schema"]["default"])
tool_parameter = ToolParameter(
name=parameter["name"],
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
human_description=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
type=ToolParameter.ToolParameterType.STRING,
required=parameter.get("required", False),
form=ToolParameter.ToolParameterForm.LLM,
llm_description=parameter.get("description"),
default=default_value,
placeholder=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
)
# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
if typ:
tool_parameter.type = typ
parameters.append(tool_parameter)
# create tool bundle
# check if there is a request body
if "requestBody" in interface["operation"]:
request_body = interface["operation"]["requestBody"]
if "content" in request_body:
for content_type, content in request_body["content"].items():
# if there is a reference, get the reference and overwrite the content
if "schema" not in content:
continue
if "$ref" in content["schema"]:
# get the reference
root = openapi
reference = content["schema"]["$ref"].split("/")[1:]
for ref in reference:
root = root[ref]
# overwrite the content
interface["operation"]["requestBody"]["content"][content_type]["schema"] = root
# handle allOf reference in schema properties
for prop_dict in root.get("properties", {}).values():
for item in prop_dict.get("allOf", []):
if "$ref" in item:
ref_schema = openapi
reference = item["$ref"].split("/")[1:]
for ref in reference:
ref_schema = ref_schema[ref]
else:
ref_schema = item
for key, value in ref_schema.items():
if isinstance(value, list):
if key not in prop_dict:
prop_dict[key] = []
# extends list field
if isinstance(prop_dict[key], list):
prop_dict[key].extend(value)
elif key not in prop_dict:
# add new field
prop_dict[key] = value
if "allOf" in prop_dict:
del prop_dict["allOf"]
# parse body parameters
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
required = body_schema.get("required", [])
properties = body_schema.get("properties", {})
for name, property in properties.items():
# Handle complex type defaults that are not supported by PluginParameter
default_value = ApiBasedToolSchemaParser._sanitize_default_value(
property.get("default", None)
)
tool = ToolParameter(
name=name,
label=I18nObject(en_US=name, zh_Hans=name),
human_description=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
type=ToolParameter.ToolParameterType.STRING,
required=name in required,
form=ToolParameter.ToolParameterForm.LLM,
llm_description=property.get("description", ""),
default=default_value,
placeholder=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
)
# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
if typ:
tool.type = typ
parameters.append(tool)
# check if parameters is duplicated
parameters_count = {}
for parameter in parameters:
if parameter.name not in parameters_count:
parameters_count[parameter.name] = 0
parameters_count[parameter.name] += 1
for name, count in parameters_count.items():
if count > 1:
warning["duplicated_parameter"] = f"Parameter {name} is duplicated."
# check if there is a operation id, use $path_$method as operation id if not
if "operationId" not in interface["operation"]:
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
path = interface["path"]
if interface["path"].startswith("/"):
path = interface["path"][1:]
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
if not path:
path = "<root>"
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
bundles.append(
ApiToolBundle(
server_url=server_url + interface["path"],
method=interface["method"],
summary=interface["operation"]["description"]
if "description" in interface["operation"]
else interface["operation"].get("summary", None),
operation_id=interface["operation"]["operationId"],
parameters=parameters,
author="",
icon=None,
openapi=interface["operation"],
)
)
return bundles
@staticmethod
def _sanitize_default_value(value):
"""
Sanitize default values for PluginParameter compatibility.
Complex types (list, dict) are converted to None to avoid validation errors.
Args:
value: The default value from OpenAPI schema
Returns:
None for complex types (list, dict), otherwise the original value
"""
if isinstance(value, (list, dict)):
return None
return value
@staticmethod
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
parameter = parameter or {}
typ: str | None = None
if parameter.get("format") == "binary":
return ToolParameter.ToolParameterType.FILE
if "type" in parameter:
typ = parameter["type"]
elif "schema" in parameter and "type" in parameter["schema"]:
typ = parameter["schema"]["type"]
if typ in {"integer", "number"}:
return ToolParameter.ToolParameterType.NUMBER
elif typ == "boolean":
return ToolParameter.ToolParameterType.BOOLEAN
elif typ == "string":
return ToolParameter.ToolParameterType.STRING
elif typ == "array":
items = parameter.get("items") or parameter.get("schema", {}).get("items")
if items and items.get("format") == "binary":
return ToolParameter.ToolParameterType.FILES
else:
# For regular arrays, return ARRAY type instead of None
return ToolParameter.ToolParameterType.ARRAY
else:
return None
@staticmethod
def parse_openapi_yaml_to_tool_bundle(
yaml: str, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
"""
parse openapi yaml to tool bundle
:param yaml: the yaml string
:param extra_info: the extra info
:param warning: the warning message
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
openapi: dict = safe_load(yaml)
if openapi is None:
raise ToolApiSchemaError("Invalid openapi yaml.")
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
@staticmethod
def parse_swagger_to_openapi(
swagger: dict, extra_info: dict | None = None, warning: dict | None = None
) -> dict[str, Any]:
        """
        parse swagger to openapi
        :param swagger: the swagger dict
        :return: the openapi dict
        """
        warning = warning or {}
# convert swagger to openapi
info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})
servers = swagger.get("servers", [])
if len(servers) == 0:
raise ToolApiSchemaError("No server found in the swagger yaml.")
converted_openapi: dict[str, Any] = {
"openapi": "3.0.0",
"info": {
"title": info.get("title", "Swagger"),
"description": info.get("description", "Swagger"),
"version": info.get("version", "1.0.0"),
},
"servers": swagger["servers"],
"paths": {},
"components": {"schemas": {}},
}
# check paths
if "paths" not in swagger or len(swagger["paths"]) == 0:
raise ToolApiSchemaError("No paths found in the swagger yaml.")
# convert paths
for path, path_item in swagger["paths"].items():
converted_openapi["paths"][path] = {}
for method, operation in path_item.items():
if "operationId" not in operation:
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
if ("summary" not in operation or len(operation["summary"]) == 0) and (
"description" not in operation or len(operation["description"]) == 0
):
if warning is not None:
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
converted_openapi["paths"][path][method] = {
"operationId": operation["operationId"],
"summary": operation.get("summary", ""),
"description": operation.get("description", ""),
"parameters": operation.get("parameters", []),
"responses": operation.get("responses", {}),
}
if "requestBody" in operation:
converted_openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
# convert definitions
if "definitions" in swagger:
for name, definition in swagger["definitions"].items():
converted_openapi["components"]["schemas"][name] = definition
return converted_openapi
@staticmethod
def parse_openai_plugin_json_to_tool_bundle(
json: str, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
"""
        parse OpenAI plugin json to tool bundle
:param json: the json string
:param extra_info: the extra info
:param warning: the warning message
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
try:
openai_plugin = json_loads(json)
api = openai_plugin["api"]
api_url = api["url"]
api_type = api["type"]
except JSONDecodeError:
raise ToolProviderNotFoundError("Invalid openai plugin json.")
if api_type != "openapi":
raise ToolNotSupportedError("Only openapi is supported now.")
# get openapi yaml
response = httpx.get(
api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5
)
try:
if response.status_code != 200:
raise ToolProviderNotFoundError("cannot get openapi yaml from url.")
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
response.text, extra_info=extra_info, warning=warning
)
finally:
response.close()
@staticmethod
def auto_parse_to_tool_bundle(
content: str, extra_info: dict | None = None, warning: dict | None = None
) -> tuple[list[ApiToolBundle], str]:
"""
auto parse to tool bundle
:param content: the content
:param extra_info: the extra info
:param warning: the warning message
:return: tools bundle, schema_type
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
content = content.strip()
loaded_content = None
json_error = None
yaml_error = None
try:
loaded_content = json_loads(content)
except JSONDecodeError as e:
json_error = e
if loaded_content is None:
try:
loaded_content = safe_load(content)
except YAMLError as e:
yaml_error = e
if loaded_content is None:
raise ToolApiSchemaError(
f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)},"
f" yaml error: {str(yaml_error)}"
)
swagger_error = None
openapi_error = None
openapi_plugin_error = None
schema_type = None
try:
openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
loaded_content, extra_info=extra_info, warning=warning
)
schema_type = ApiProviderSchemaType.OPENAPI
return openapi, schema_type
except ToolApiSchemaError as e:
openapi_error = e
            # openapi parse error, fallback to swagger
try:
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
loaded_content, extra_info=extra_info, warning=warning
)
schema_type = ApiProviderSchemaType.SWAGGER
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
converted_swagger, extra_info=extra_info, warning=warning
), schema_type
except ToolApiSchemaError as e:
swagger_error = e
# swagger parse error, fallback to openai plugin
try:
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
json_dumps(loaded_content), extra_info=extra_info, warning=warning
)
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN
except ToolNotSupportedError as e:
# maybe it's not plugin at all
openapi_plugin_error = e
raise ToolApiSchemaError(
f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)},"
f" openapi plugin error: {str(openapi_plugin_error)}"
)
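# A small sketch (not part of the original module) of the parameter-type mapping above;
# _get_tool_parameter_type is a pure helper, so it can be exercised directly. Note that the full
# parsers read the X-Request-Env header via flask.request, so they expect a request context.
if __name__ == "__main__":
    assert ApiBasedToolSchemaParser._get_tool_parameter_type({"schema": {"type": "integer"}}) \
        == ToolParameter.ToolParameterType.NUMBER
    assert ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "string", "format": "binary"}) \
        == ToolParameter.ToolParameterType.FILE
    assert ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "array", "items": {"type": "string"}}) \
        == ToolParameter.ToolParameterType.ARRAY
    print("parameter type mapping ok")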

View File

@@ -0,0 +1,187 @@
import base64
import hashlib
import logging
from collections.abc import Mapping
from typing import Any
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad
from pydantic import TypeAdapter
from configs import dify_config
logger = logging.getLogger(__name__)
class OAuthEncryptionError(Exception):
"""OAuth encryption/decryption specific error"""
pass
class SystemOAuthEncrypter:
"""
A simple OAuth parameters encrypter using AES-CBC encryption.
This class provides methods to encrypt and decrypt OAuth parameters
using AES-CBC mode with a key derived from the application's SECRET_KEY.
"""
def __init__(self, secret_key: str | None = None):
"""
Initialize the OAuth encrypter.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
Raises:
ValueError: If SECRET_KEY is not configured or empty
"""
        secret_key = secret_key or dify_config.SECRET_KEY
        if not secret_key:
            raise ValueError("SECRET_KEY is not configured or is empty")
        # Generate a fixed 256-bit key using SHA-256
        self.key = hashlib.sha256(secret_key.encode()).digest()
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters.
Args:
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
Returns:
Base64-encoded encrypted string
Raises:
OAuthEncryptionError: If encryption fails
ValueError: If oauth_params is invalid
"""
try:
# Generate random IV (16 bytes)
iv = get_random_bytes(16)
# Create AES cipher (CBC mode)
cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Encrypt data
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
encrypted_data = cipher.encrypt(padded_data)
# Combine IV and encrypted data
combined = iv + encrypted_data
# Return base64 encoded string
return base64.b64encode(combined).decode()
except Exception as e:
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
Raises:
OAuthEncryptionError: If decryption fails
ValueError: If encrypted_data is invalid
"""
if not isinstance(encrypted_data, str):
raise ValueError("encrypted_data must be a string")
if not encrypted_data:
raise ValueError("encrypted_data cannot be empty")
try:
# Base64 decode
combined = base64.b64decode(encrypted_data)
# Check minimum length (IV + at least one AES block)
if len(combined) < 32: # 16 bytes IV + 16 bytes minimum encrypted data
raise ValueError("Invalid encrypted data format")
# Separate IV and encrypted data
iv = combined[:16]
encrypted_data_bytes = combined[16:]
# Create AES cipher
cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Decrypt data
decrypted_data = cipher.decrypt(encrypted_data_bytes)
unpadded_data = unpad(decrypted_data, AES.block_size)
# Parse JSON
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
if not isinstance(oauth_params, dict):
raise ValueError("Decrypted data is not a valid dictionary")
return oauth_params
except Exception as e:
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
# Factory function for creating encrypter instances
def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter:
"""
Create an OAuth encrypter instance.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
Returns:
SystemOAuthEncrypter instance
"""
return SystemOAuthEncrypter(secret_key=secret_key)
# Global encrypter instance (for backward compatibility)
_oauth_encrypter: SystemOAuthEncrypter | None = None
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
"""
Get the global OAuth encrypter instance.
Returns:
SystemOAuthEncrypter instance
"""
global _oauth_encrypter
if _oauth_encrypter is None:
_oauth_encrypter = SystemOAuthEncrypter()
return _oauth_encrypter
# Convenience functions for backward compatibility
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters using the global encrypter.
Args:
oauth_params: OAuth parameters dictionary
Returns:
Base64-encoded encrypted string
"""
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters using the global encrypter.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
"""
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
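# A self-contained round-trip sketch (not part of the original module); the secret key below
# is an explicit test value rather than the application's SECRET_KEY.
if __name__ == "__main__":
    encrypter = SystemOAuthEncrypter(secret_key="a-test-only-secret")
    token = encrypter.encrypt_oauth_params({"client_id": "abc", "client_secret": "xyz"})
    params = encrypter.decrypt_oauth_params(token)
    assert params == {"client_id": "abc", "client_secret": "xyz"}
    print("oauth params round-trip ok")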

View File

@@ -0,0 +1,17 @@
import re
def remove_leading_symbols(text: str) -> str:
"""
Remove leading punctuation or symbols from the given text.
Args:
text (str): The input text to process.
Returns:
str: The text with leading punctuation or symbols removed.
"""
# Match Unicode ranges for punctuation and symbols
    # FIXME: this pattern is a confusing quick fix for #11868; consider refactoring it later
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
return re.sub(pattern, "", text)
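# A quick sketch (not part of the original module) of the helper's effect; the sample strings are arbitrary.
if __name__ == "__main__":
    print(remove_leading_symbols("...hello world"))  # -> "hello world"
    print(remove_leading_symbols("。第一章"))          # leading CJK punctuation is stripped
    print(remove_leading_symbols("already clean"))   # unchanged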

View File

@@ -0,0 +1,11 @@
import uuid
def is_valid_uuid(uuid_str: str | None) -> bool:
if uuid_str is None or len(uuid_str) == 0:
return False
try:
uuid.UUID(uuid_str)
return True
except Exception:
return False
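# A quick sketch (not part of the original module); note that uuid.UUID also accepts some
# non-hyphenated hex forms, so this check is permissive rather than strict.
if __name__ == "__main__":
    assert is_valid_uuid("123e4567-e89b-12d3-a456-426614174000")
    assert not is_valid_uuid("not-a-uuid")
    assert not is_valid_uuid("")
    assert not is_valid_uuid(None)
    print("uuid checks ok")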

View File

@@ -0,0 +1,128 @@
import mimetypes
import re
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, cast
from urllib.parse import unquote
import chardet
import cloudscraper
from readabilipy import simple_json_from_html_string
from core.helper import ssrf_proxy
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor
FULL_TEMPLATE = """
TITLE: {title}
AUTHOR: {author}
TEXT:
{text}
"""
def page_result(text: str, cursor: int, max_length: int) -> str:
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
return text[cursor : cursor + max_length]
def get_url(url: str, user_agent: str | None = None) -> str:
"""Fetch URL and return the contents as a string."""
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
" Chrome/91.0.4472.124 Safari/537.36"
}
if user_agent:
headers["User-Agent"] = user_agent
main_content_type = None
supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10))
if response.status_code == 200:
# check content-type
content_type = response.headers.get("Content-Type")
if content_type:
            main_content_type = content_type.split(";")[0].strip()
else:
content_disposition = response.headers.get("Content-Disposition", "")
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
if filename_match:
filename = unquote(filename_match.group(1))
extension = re.search(r"\.(\w+)$", filename)
if extension:
main_content_type = mimetypes.guess_type(filename)[0]
if main_content_type not in supported_content_types:
return f"Unsupported content-type [{main_content_type}] of URL."
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
return cast(str, ExtractProcessor.load_from_url(url, return_text=True))
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
elif response.status_code == 403:
scraper = cloudscraper.create_scraper()
scraper.perform_request = ssrf_proxy.make_request
response = scraper.get(url, headers=headers, timeout=(120, 300))
if response.status_code != 200:
return f"URL returned status code {response.status_code}."
# Detect encoding using chardet
detected_encoding = chardet.detect(response.content)
encoding = detected_encoding["encoding"]
if encoding:
try:
content = response.content.decode(encoding)
except (UnicodeDecodeError, TypeError):
content = response.text
else:
content = response.text
article = extract_using_readabilipy(content)
if not article.text:
return ""
res = FULL_TEMPLATE.format(
title=article.title,
author=article.author,
text=article.text,
)
return res
@dataclass
class Article:
title: str
author: str
text: Sequence[dict]
def extract_using_readabilipy(html: str):
json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True)
article = Article(
title=json_article.get("title") or "",
author=json_article.get("byline") or "",
text=json_article.get("plain_text") or [],
)
return article
def get_image_upload_file_ids(content):
pattern = r"!\[image\]\((https?://.*?(file-preview|image-preview))\)"
matches = re.findall(pattern, content)
image_upload_file_ids = []
for match in matches:
if match[1] == "file-preview":
content_pattern = r"files/([^/]+)/file-preview"
else:
content_pattern = r"files/([^/]+)/image-preview"
content_match = re.search(content_pattern, match[0])
if content_match:
image_upload_file_id = content_match.group(1)
image_upload_file_ids.append(image_upload_file_id)
return image_upload_file_ids
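# A small sketch (not part of the original module) of the markdown scan above; the URLs are
# placeholders, and only the upload-file id between files/ and the preview suffix is extracted.
if __name__ == "__main__":
    markdown = (
        "Intro text ![image](https://example.com/files/abc123/image-preview) "
        "and ![image](https://example.com/files/def456/file-preview)"
    )
    print(get_image_upload_file_ids(markdown))  # -> ["abc123", "def456"]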

View File

@@ -0,0 +1,43 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
for configuration in configurations:
WorkflowToolParameterConfiguration.model_validate(configuration)
@classmethod
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
"""
get workflow graph variables
"""
nodes = graph.get("nodes", [])
start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)
if not start_node:
return []
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
@classmethod
def check_is_synced(
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
):
"""
        check whether the workflow variables and the tool parameter configurations are in sync;
        raise ValueError if they are not
"""
variable_names = [variable.variable for variable in variables]
if len(tool_configurations) != len(variables):
raise ValueError("parameter configuration mismatch, please republish the tool to update")
for parameter in tool_configurations:
if parameter.name not in variable_names:
raise ValueError("parameter configuration mismatch, please republish the tool to update")

View File

@@ -0,0 +1,33 @@
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any
import yaml
from yaml import YAMLError
logger = logging.getLogger(__name__)
def _load_yaml_file(*, file_path: str):
if not file_path or not Path(file_path).exists():
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, encoding="utf-8") as yaml_file:
try:
yaml_content = yaml.safe_load(yaml_file)
return yaml_content
except Exception as e:
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
@lru_cache(maxsize=128)
def load_yaml_file_cached(file_path: str) -> Any:
"""
Cached version of load_yaml_file for static configuration files.
Only use for files that don't change during runtime (e.g., position files)
:param file_path: the path of the YAML file
:return: an object of the YAML content
"""
return _load_yaml_file(file_path=file_path)
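# A self-contained sketch (not part of the original module): load a temporary YAML file through
# the cached loader; the second call for the same path is served from the lru_cache.
if __name__ == "__main__":
    import os
    import tempfile

    with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
        f.write("provider: example\npriority: 1\n")
        path = f.name
    try:
        print(load_yaml_file_cached(path))  # {'provider': 'example', 'priority': 1}
        print(load_yaml_file_cached(path))  # same object, returned from the cache
    finally:
        os.unlink(path)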

View File

@@ -0,0 +1,240 @@
from collections.abc import Mapping
from pydantic import Field
from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.plugin.entities.parameters import PluginParameterOption
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ToolDescription,
ToolEntity,
ToolIdentity,
ToolParameter,
ToolProviderEntity,
ToolProviderIdentity,
ToolProviderType,
)
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.account import Account
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import Workflow
VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING,
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
VariableEntityType.CHECKBOX: ToolParameter.ToolParameterType.BOOLEAN,
VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE,
VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES,
}
class WorkflowToolProviderController(ToolProviderController):
provider_id: str
tools: list[WorkflowTool] = Field(default_factory=list)
def __init__(self, entity: ToolProviderEntity, provider_id: str):
super().__init__(entity=entity)
self.provider_id = provider_id
@classmethod
def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
with Session(db.engine, expire_on_commit=False) as session, session.begin():
provider = session.get(WorkflowToolProvider, db_provider.id) if db_provider.id else None
if not provider:
raise ValueError("workflow provider not found")
app = session.get(App, provider.app_id)
if not app:
raise ValueError("app not found")
user = session.get(Account, provider.user_id) if provider.user_id else None
controller = WorkflowToolProviderController(
entity=ToolProviderEntity(
identity=ToolProviderIdentity(
author=user.name if user else "",
name=provider.label,
label=I18nObject(en_US=provider.label, zh_Hans=provider.label),
description=I18nObject(en_US=provider.description, zh_Hans=provider.description),
icon=provider.icon,
),
credentials_schema=[],
plugin_id=None,
),
provider_id=provider.id or "",
)
controller.tools = [
controller._get_db_provider_tool(provider, app, session=session, user=user),
]
return controller
@property
def provider_type(self) -> ToolProviderType:
return ToolProviderType.WORKFLOW
def _get_db_provider_tool(
self,
db_provider: WorkflowToolProvider,
app: App,
*,
session: Session,
user: Account | None = None,
) -> WorkflowTool:
"""
get db provider tool
:param db_provider: the db provider
:param app: the app
:return: the tool
"""
workflow: Workflow | None = (
session.query(Workflow)
.where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
.first()
)
if not workflow:
raise ValueError("workflow not found")
# fetch start node
graph: Mapping = workflow.graph_dict
features_dict: Mapping = workflow.features_dict
features = WorkflowAppConfigManager.convert_features(config_dict=features_dict, app_mode=AppMode.WORKFLOW)
parameters = db_provider.parameter_configurations
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
return next(filter(lambda x: x.variable == variable_name, variables), None)
workflow_tool_parameters = []
for parameter in parameters:
variable = fetch_workflow_variable(parameter.name)
if variable:
parameter_type = None
options = []
if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
raise ValueError(f"unsupported variable type {variable.type}")
parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]
if variable.type == VariableEntityType.SELECT and variable.options:
options = [
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
for option in variable.options
]
workflow_tool_parameters.append(
ToolParameter(
name=parameter.name,
label=I18nObject(en_US=variable.label, zh_Hans=variable.label),
human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description),
type=parameter_type,
form=parameter.form,
llm_description=parameter.description,
required=variable.required,
default=variable.default,
options=options,
placeholder=I18nObject(en_US="", zh_Hans=""),
)
)
elif features.file_upload:
workflow_tool_parameters.append(
ToolParameter(
name=parameter.name,
label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name),
human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description),
type=ToolParameter.ToolParameterType.SYSTEM_FILES,
llm_description=parameter.description,
required=False,
form=parameter.form,
placeholder=I18nObject(en_US="", zh_Hans=""),
)
)
else:
raise ValueError("variable not found")
return WorkflowTool(
workflow_as_tool_id=db_provider.id,
entity=ToolEntity(
identity=ToolIdentity(
author=user.name if user else "",
name=db_provider.name,
label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label),
provider=self.provider_id,
icon=db_provider.icon,
),
description=ToolDescription(
human=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
llm=db_provider.description,
),
parameters=workflow_tool_parameters,
),
runtime=ToolRuntime(
tenant_id=db_provider.tenant_id,
),
workflow_app_id=app.id,
workflow_entities={
"app": app,
"workflow": workflow,
},
version=db_provider.version,
workflow_call_depth=0,
label=db_provider.label,
)
def get_tools(self, tenant_id: str) -> list[WorkflowTool]:
"""
fetch tools from database
:param tenant_id: the tenant id
:return: the tools
"""
if self.tools is not None:
return self.tools
with Session(db.engine, expire_on_commit=False) as session, session.begin():
db_provider: WorkflowToolProvider | None = (
session.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id,
)
.first()
)
if not db_provider:
return []
app = session.get(App, db_provider.app_id)
if not app:
raise ValueError("app not found")
user = session.get(Account, db_provider.user_id) if db_provider.user_id else None
self.tools = [self._get_db_provider_tool(db_provider, app, session=session, user=user)]
return self.tools
def get_tool(self, tool_name: str) -> WorkflowTool | None: # type: ignore
"""
get tool by name
:param tool_name: the name of the tool
:return: the tool
"""
if self.tools is None:
return None
for tool in self.tools:
if tool.entity.identity.name == tool_name:
return tool
return None
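# A hedged sketch (not part of the original module) of how the controller is typically used,
# assuming a WorkflowToolProvider row already loaded from the database; ids are placeholders.
def _example_list_workflow_tools(db_provider: WorkflowToolProvider) -> None:  # hypothetical helper
    controller = WorkflowToolProviderController.from_db(db_provider)
    for tool in controller.get_tools(tenant_id=db_provider.tenant_id):
        print(tool.entity.identity.name, tool.entity.description.llm)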

View File

@@ -0,0 +1,346 @@
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from flask import has_request_context
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import (
ToolEntity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
from core.tools.errors import ToolInvokeError
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from libs.login import current_user
from models import Account, Tenant
from models.model import App, EndUser
from models.workflow import Workflow
logger = logging.getLogger(__name__)
class WorkflowTool(Tool):
"""
Workflow tool.
"""
def __init__(
self,
workflow_app_id: str,
workflow_as_tool_id: str,
version: str,
workflow_entities: dict[str, Any],
workflow_call_depth: int,
entity: ToolEntity,
runtime: ToolRuntime,
label: str = "Workflow",
):
self.workflow_app_id = workflow_app_id
self.workflow_as_tool_id = workflow_as_tool_id
self.version = version
self.workflow_entities = workflow_entities
self.workflow_call_depth = workflow_call_depth
self.label = label
self._latest_usage = LLMUsage.empty_usage()
super().__init__(entity=entity, runtime=runtime)
def tool_provider_type(self) -> ToolProviderType:
"""
get the tool provider type
:return: the tool provider type
"""
return ToolProviderType.WORKFLOW
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
invoke the tool
"""
app = self._get_app(app_id=self.workflow_app_id)
workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version)
# transform the tool parameters
tool_parameters, files = self._transform_args(tool_parameters=tool_parameters)
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
generator = WorkflowAppGenerator()
assert self.runtime is not None
assert self.runtime.invoke_from is not None
user = self._resolve_user(user_id=user_id)
if user is None:
raise ToolInvokeError("User not found")
self._latest_usage = LLMUsage.empty_usage()
result = generator.generate(
app_model=app,
workflow=workflow,
user=user,
args={"inputs": tool_parameters, "files": files},
invoke_from=self.runtime.invoke_from,
streaming=False,
call_depth=self.workflow_call_depth + 1,
)
assert isinstance(result, dict)
data = result.get("data", {})
if err := data.get("error"):
raise ToolInvokeError(err)
outputs = data.get("outputs")
if outputs is None:
outputs = {}
else:
outputs, files = self._extract_files(outputs) # type: ignore
for file in files:
yield self.create_file_message(file) # type: ignore
self._latest_usage = self._derive_usage_from_result(data)
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
yield self.create_json_message(outputs, suppress_output=True)
@property
def latest_usage(self) -> LLMUsage:
return self._latest_usage
@classmethod
def _derive_usage_from_result(cls, data: Mapping[str, Any]) -> LLMUsage:
usage_dict = cls._extract_usage_dict(data)
if usage_dict is not None:
return LLMUsage.from_metadata(cast(LLMUsageMetadata, dict(usage_dict)))
total_tokens = data.get("total_tokens")
total_price = data.get("total_price")
if total_tokens is None and total_price is None:
return LLMUsage.empty_usage()
usage_metadata: dict[str, Any] = {}
if total_tokens is not None:
try:
usage_metadata["total_tokens"] = int(str(total_tokens))
except (TypeError, ValueError):
pass
if total_price is not None:
usage_metadata["total_price"] = str(total_price)
currency = data.get("currency")
if currency is not None:
usage_metadata["currency"] = currency
if not usage_metadata:
return LLMUsage.empty_usage()
return LLMUsage.from_metadata(cast(LLMUsageMetadata, usage_metadata))
@classmethod
def _extract_usage_dict(cls, payload: Mapping[str, Any]) -> Mapping[str, Any] | None:
usage_candidate = payload.get("usage")
if isinstance(usage_candidate, Mapping):
return usage_candidate
metadata_candidate = payload.get("metadata")
if isinstance(metadata_candidate, Mapping):
usage_candidate = metadata_candidate.get("usage")
if isinstance(usage_candidate, Mapping):
return usage_candidate
for value in payload.values():
if isinstance(value, Mapping):
found = cls._extract_usage_dict(value)
if found is not None:
return found
elif isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
for item in value:
if isinstance(item, Mapping):
found = cls._extract_usage_dict(item)
if found is not None:
return found
return None
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
"""
fork a new tool with metadata
:return: the new tool
"""
return self.__class__(
entity=self.entity.model_copy(),
runtime=runtime,
workflow_app_id=self.workflow_app_id,
workflow_as_tool_id=self.workflow_as_tool_id,
workflow_entities=self.workflow_entities,
workflow_call_depth=self.workflow_call_depth,
version=self.version,
label=self.label,
)
def _resolve_user(self, user_id: str) -> Account | EndUser | None:
"""
Resolve user object in both HTTP and worker contexts.
In HTTP context: dereference the current_user LocalProxy (can return Account or EndUser).
In worker context: load Account from database by user_id (only returns Account, never EndUser).
Returns:
Account | EndUser | None: The resolved user object, or None if resolution fails.
"""
if has_request_context():
return self._resolve_user_from_request()
else:
return self._resolve_user_from_database(user_id=user_id)
def _resolve_user_from_request(self) -> Account | EndUser | None:
"""
Resolve user from Flask request context.
"""
try:
# Note: `current_user` is a LocalProxy. Never compare it with None directly.
return getattr(current_user, "_get_current_object", lambda: current_user)()
except Exception as e:
logger.warning("Failed to resolve user from request context: %s", e)
return None
def _resolve_user_from_database(self, user_id: str) -> Account | None:
"""
Resolve user from database (worker/Celery context).
"""
user_stmt = select(Account).where(Account.id == user_id)
user = db.session.scalar(user_stmt)
if not user:
return None
tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
tenant = db.session.scalar(tenant_stmt)
if not tenant:
return None
user.current_tenant = tenant
return user
def _get_workflow(self, app_id: str, version: str) -> Workflow:
"""
get the workflow by app id and version
"""
with Session(db.engine, expire_on_commit=False) as session, session.begin():
if not version:
stmt = (
select(Workflow)
.where(Workflow.app_id == app_id, Workflow.version != Workflow.VERSION_DRAFT)
.order_by(Workflow.created_at.desc())
)
workflow = session.scalars(stmt).first()
else:
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
workflow = session.scalar(stmt)
if not workflow:
raise ValueError("workflow not found or not published")
return workflow
def _get_app(self, app_id: str) -> App:
"""
get the app by app id
"""
stmt = select(App).where(App.id == app_id)
with Session(db.engine, expire_on_commit=False) as session, session.begin():
app = session.scalar(stmt)
if not app:
raise ValueError("app not found")
return app
def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
"""
transform the tool parameters
:param tool_parameters: the tool parameters
:return: tool_parameters, files
"""
parameter_rules = self.get_merged_runtime_parameters()
parameters_result = {}
files = []
for parameter in parameter_rules:
if parameter.type == ToolParameter.ToolParameterType.SYSTEM_FILES:
file = tool_parameters.get(parameter.name)
if file:
try:
file_var_list = [File.model_validate(f) for f in file]
for file in file_var_list:
file_dict: dict[str, str | None] = {
"transfer_method": file.transfer_method.value,
"type": file.type.value,
}
if file.transfer_method == FileTransferMethod.TOOL_FILE:
file_dict["tool_file_id"] = file.related_id
elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
file_dict["upload_file_id"] = file.related_id
elif file.transfer_method == FileTransferMethod.REMOTE_URL:
file_dict["url"] = file.generate_url()
files.append(file_dict)
except Exception:
logger.exception("Failed to transform file %s", file)
else:
parameters_result[parameter.name] = tool_parameters.get(parameter.name)
return parameters_result, files
def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]:
"""
extract files from the result
:return: the result, files
"""
files: list[File] = []
result = {}
for key, value in outputs.items():
if isinstance(value, list):
for item in value:
if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY:
item = self._update_file_mapping(item)
file = build_from_mapping(
mapping=item,
tenant_id=str(self.runtime.tenant_id),
)
files.append(file)
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
value = self._update_file_mapping(value)
file = build_from_mapping(
mapping=value,
tenant_id=str(self.runtime.tenant_id),
)
files.append(file)
result[key] = value
return result, files
def _update_file_mapping(self, file_dict: dict):
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
if transfer_method == FileTransferMethod.TOOL_FILE:
file_dict["tool_file_id"] = file_dict.get("related_id")
elif transfer_method == FileTransferMethod.LOCAL_FILE:
file_dict["upload_file_id"] = file_dict.get("related_id")
return file_dict
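# A small sketch (not part of the original module) of the usage extraction above;
# _extract_usage_dict walks nested mappings, so a usage block tucked under node metadata is
# still found. The payload below is illustrative only.
if __name__ == "__main__":
    payload = {
        "outputs": {"answer": "ok"},
        "metadata": {"usage": {"total_tokens": 42, "total_price": "0.0008", "currency": "USD"}},
    }
    print(WorkflowTool._extract_usage_dict(payload))  # -> the nested usage mapping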