dify
This commit is contained in:
233
dify/api/core/tools/__base/tool.py
Normal file
233
dify/api/core/tools/__base/tool.py
Normal file
@@ -0,0 +1,233 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.model import File
|
||||
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolEntity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
"""
|
||||
The base class of a tool
|
||||
"""
|
||||
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime):
|
||||
self.entity = entity
|
||||
self.runtime = runtime
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool":
|
||||
"""
|
||||
fork a new tool with metadata
|
||||
:return: the new tool
|
||||
"""
|
||||
return self.__class__(
|
||||
entity=self.entity.model_copy(),
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
get the tool provider type
|
||||
|
||||
:return: the tool provider type
|
||||
"""
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage]:
|
||||
if self.runtime and self.runtime.runtime_parameters:
|
||||
tool_parameters.update(self.runtime.runtime_parameters)
|
||||
|
||||
# try parse tool parameters into the correct type
|
||||
tool_parameters = self._transform_tool_parameters_type(tool_parameters)
|
||||
|
||||
result = self._invoke(
|
||||
user_id=user_id,
|
||||
tool_parameters=tool_parameters,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
if isinstance(result, ToolInvokeMessage):
|
||||
|
||||
def single_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield result
|
||||
|
||||
return single_generator()
|
||||
elif isinstance(result, list):
|
||||
|
||||
def generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield from result
|
||||
|
||||
return generator()
|
||||
else:
|
||||
return result
|
||||
|
||||
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Transform tool parameters type
|
||||
"""
|
||||
# Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials
|
||||
result = deepcopy(tool_parameters)
|
||||
for parameter in self.entity.parameters or []:
|
||||
if parameter.name in tool_parameters:
|
||||
result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
|
||||
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]:
|
||||
pass
|
||||
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> list[ToolParameter]:
|
||||
"""
|
||||
get the runtime parameters
|
||||
|
||||
interface for developer to dynamic change the parameters of a tool depends on the variables pool
|
||||
|
||||
:return: the runtime parameters
|
||||
"""
|
||||
return self.entity.parameters
|
||||
|
||||
def get_merged_runtime_parameters(
|
||||
self,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> list[ToolParameter]:
|
||||
"""
|
||||
get merged runtime parameters
|
||||
|
||||
:return: merged runtime parameters
|
||||
"""
|
||||
parameters = self.entity.parameters
|
||||
parameters = parameters.copy()
|
||||
user_parameters = self.get_runtime_parameters() or []
|
||||
user_parameters = user_parameters.copy()
|
||||
|
||||
# override parameters
|
||||
for parameter in user_parameters:
|
||||
# check if parameter in tool parameters
|
||||
for tool_parameter in parameters:
|
||||
if tool_parameter.name == parameter.name:
|
||||
# override parameter
|
||||
tool_parameter.type = parameter.type
|
||||
tool_parameter.form = parameter.form
|
||||
tool_parameter.required = parameter.required
|
||||
tool_parameter.default = parameter.default
|
||||
tool_parameter.options = parameter.options
|
||||
tool_parameter.llm_description = parameter.llm_description
|
||||
break
|
||||
else:
|
||||
# add new parameter
|
||||
parameters.append(parameter)
|
||||
|
||||
return parameters
|
||||
|
||||
def create_image_message(
|
||||
self,
|
||||
image: str,
|
||||
) -> ToolInvokeMessage:
|
||||
"""
|
||||
create an image message
|
||||
|
||||
:param image: the url of the image
|
||||
:return: the image message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
|
||||
)
|
||||
|
||||
def create_file_message(self, file: "File") -> ToolInvokeMessage:
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.FILE,
|
||||
message=ToolInvokeMessage.FileMessage(),
|
||||
meta={"file": file},
|
||||
)
|
||||
|
||||
def create_link_message(self, link: str) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a link message
|
||||
|
||||
:param link: the url of the link
|
||||
:return: the link message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link)
|
||||
)
|
||||
|
||||
def create_text_message(self, text: str) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a text message
|
||||
|
||||
:param text: the text
|
||||
:return: the text message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.TEXT,
|
||||
message=ToolInvokeMessage.TextMessage(text=text),
|
||||
)
|
||||
|
||||
def create_blob_message(self, blob: bytes, meta: dict | None = None) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a blob message
|
||||
|
||||
:param blob: the blob
|
||||
:param meta: the meta info of blob object
|
||||
:return: the blob message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB,
|
||||
message=ToolInvokeMessage.BlobMessage(blob=blob),
|
||||
meta=meta,
|
||||
)
|
||||
|
||||
def create_json_message(self, object: dict, suppress_output: bool = False) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a json message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.JSON,
|
||||
message=ToolInvokeMessage.JsonMessage(json_object=object, suppress_output=suppress_output),
|
||||
)
|
||||
|
||||
def create_variable_message(
|
||||
self, variable_name: str, variable_value: Any, stream: bool = False
|
||||
) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a variable message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.VARIABLE,
|
||||
message=ToolInvokeMessage.VariableMessage(
|
||||
variable_name=variable_name, variable_value=variable_value, stream=stream
|
||||
),
|
||||
)
|
||||
107
dify/api/core/tools/__base/tool_provider.py
Normal file
107
dify/api/core/tools/__base/tool_provider.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolProviderEntity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class ToolProviderController(ABC):
|
||||
def __init__(self, entity: ToolProviderEntity):
|
||||
self.entity = entity
|
||||
|
||||
def get_credentials_schema(self) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:return: the credentials schema
|
||||
"""
|
||||
return deepcopy(self.entity.credentials_schema)
|
||||
|
||||
@abstractmethod
|
||||
def get_tool(self, tool_name: str) -> Tool:
|
||||
"""
|
||||
returns a tool that the provider can provide
|
||||
|
||||
:return: tool
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
def validate_credentials_format(self, credentials: dict[str, Any]):
|
||||
"""
|
||||
validate the format of the credentials of the provider and set the default value if needed
|
||||
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
credentials_schema = dict[str, ProviderConfig]()
|
||||
if credentials_schema is None:
|
||||
return
|
||||
|
||||
for credential in self.entity.credentials_schema:
|
||||
credentials_schema[credential.name] = credential
|
||||
|
||||
credentials_need_to_validate: dict[str, ProviderConfig] = {}
|
||||
for credential_name in credentials_schema:
|
||||
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
|
||||
|
||||
for credential_name in credentials:
|
||||
if credential_name not in credentials_need_to_validate:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"credential {credential_name} not found in provider {self.entity.identity.name}"
|
||||
)
|
||||
|
||||
# check type
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if not credential_schema.required and credentials[credential_name] is None:
|
||||
continue
|
||||
|
||||
if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||
|
||||
elif credential_schema.type == ProviderConfig.Type.SELECT:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||
|
||||
options = credential_schema.options
|
||||
if not isinstance(options, list):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
|
||||
|
||||
if credentials[credential_name] not in [x.value for x in options]:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"credential {credential_name} should be one of {options}"
|
||||
)
|
||||
|
||||
credentials_need_to_validate.pop(credential_name)
|
||||
|
||||
for credential_name in credentials_need_to_validate:
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if credential_schema.required:
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
|
||||
|
||||
# the credential is not set currently, set the default value if needed
|
||||
if credential_schema.default is not None:
|
||||
default_value = credential_schema.default
|
||||
# parse default value into the correct type
|
||||
if credential_schema.type in {
|
||||
ProviderConfig.Type.SECRET_INPUT,
|
||||
ProviderConfig.Type.TEXT_INPUT,
|
||||
ProviderConfig.Type.SELECT,
|
||||
}:
|
||||
default_value = str(default_value)
|
||||
|
||||
credentials[credential_name] = default_value
|
||||
37
dify/api/core/tools/__base/tool_runtime.py
Normal file
37
dify/api/core/tools/__base/tool_runtime.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.entities.tool_entities import ToolInvokeFrom
|
||||
|
||||
|
||||
class ToolRuntime(BaseModel):
|
||||
"""
|
||||
Meta data of a tool call processing
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
tool_id: str | None = None
|
||||
invoke_from: InvokeFrom | None = None
|
||||
tool_invoke_from: ToolInvokeFrom | None = None
|
||||
credentials: dict[str, Any] = Field(default_factory=dict)
|
||||
credential_type: CredentialType = Field(default=CredentialType.API_KEY)
|
||||
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class FakeToolRuntime(ToolRuntime):
|
||||
"""
|
||||
Fake tool runtime for testing
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
tenant_id="fake_tenant_id",
|
||||
tool_id="fake_tool_id",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from=ToolInvokeFrom.AGENT,
|
||||
credentials={},
|
||||
runtime_parameters={},
|
||||
)
|
||||
0
dify/api/core/tools/__init__.py
Normal file
0
dify/api/core/tools/__init__.py
Normal file
4
dify/api/core/tools/builtin_tool/_position.yaml
Normal file
4
dify/api/core/tools/builtin_tool/_position.yaml
Normal file
@@ -0,0 +1,4 @@
|
||||
- audio
|
||||
- code
|
||||
- time
|
||||
- webscraper
|
||||
221
dify/api/core/tools/builtin_tool/provider.py
Normal file
221
dify/api/core/tools/builtin_tool/provider.py
Normal file
@@ -0,0 +1,221 @@
|
||||
from abc import abstractmethod
|
||||
from os import listdir, path
|
||||
from typing import Any
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import (
|
||||
OAuthSchema,
|
||||
ToolEntity,
|
||||
ToolProviderEntity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
|
||||
from core.tools.errors import (
|
||||
ToolProviderNotFoundError,
|
||||
)
|
||||
from core.tools.utils.yaml_utils import load_yaml_file_cached
|
||||
|
||||
|
||||
class BuiltinToolProviderController(ToolProviderController):
|
||||
tools: list[BuiltinTool]
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
self.tools = []
|
||||
|
||||
# load provider yaml
|
||||
provider = self.__class__.__module__.split(".")[-1]
|
||||
yaml_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, f"{provider}.yaml")
|
||||
try:
|
||||
provider_yaml = load_yaml_file_cached(yaml_path)
|
||||
except Exception as e:
|
||||
raise ToolProviderNotFoundError(f"can not load provider yaml for {provider}: {e}")
|
||||
|
||||
if "credentials_for_provider" in provider_yaml and provider_yaml["credentials_for_provider"] is not None:
|
||||
# set credentials name
|
||||
for credential_name in provider_yaml["credentials_for_provider"]:
|
||||
provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name
|
||||
|
||||
credentials_schema = []
|
||||
for credential in provider_yaml.get("credentials_for_provider", {}):
|
||||
credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {})
|
||||
credentials_schema.append(credential_dict)
|
||||
|
||||
oauth_schema = None
|
||||
if provider_yaml.get("oauth_schema", None) is not None:
|
||||
oauth_schema = OAuthSchema(
|
||||
client_schema=provider_yaml.get("oauth_schema", {}).get("client_schema", []),
|
||||
credentials_schema=provider_yaml.get("oauth_schema", {}).get("credentials_schema", []),
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
entity=ToolProviderEntity(
|
||||
identity=provider_yaml["identity"],
|
||||
credentials_schema=credentials_schema,
|
||||
oauth_schema=oauth_schema,
|
||||
),
|
||||
)
|
||||
|
||||
self._load_tools()
|
||||
|
||||
def _load_tools(self):
|
||||
provider = self.entity.identity.name
|
||||
tool_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, "tools")
|
||||
# get all the yaml files in the tool path
|
||||
tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
|
||||
tools = []
|
||||
for tool_file in tool_files:
|
||||
# get tool name
|
||||
tool_name = tool_file.split(".")[0]
|
||||
tool = load_yaml_file_cached(path.join(tool_path, tool_file))
|
||||
|
||||
# get tool class, import the module
|
||||
assistant_tool_class: type = load_single_subclass_from_source(
|
||||
module_name=f"core.tools.builtin_tool.providers.{provider}.tools.{tool_name}",
|
||||
script_path=path.join(
|
||||
path.dirname(path.realpath(__file__)),
|
||||
"builtin_tool",
|
||||
"providers",
|
||||
provider,
|
||||
"tools",
|
||||
f"{tool_name}.py",
|
||||
),
|
||||
parent_type=BuiltinTool,
|
||||
)
|
||||
tool["identity"]["provider"] = provider
|
||||
tools.append(
|
||||
assistant_tool_class(
|
||||
provider=provider,
|
||||
entity=ToolEntity.model_validate(tool),
|
||||
runtime=ToolRuntime(tenant_id=""),
|
||||
)
|
||||
)
|
||||
|
||||
self.tools = tools
|
||||
|
||||
def _get_builtin_tools(self) -> list[BuiltinTool]:
|
||||
"""
|
||||
returns a list of tools that the provider can provide
|
||||
|
||||
:return: list of tools
|
||||
"""
|
||||
return self.tools
|
||||
|
||||
def get_credentials_schema(self) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:return: the credentials schema
|
||||
"""
|
||||
return self.get_credentials_schema_by_type(CredentialType.API_KEY)
|
||||
|
||||
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:param credential_type: the type of the credential
|
||||
:return: the credentials schema of the provider
|
||||
"""
|
||||
if credential_type == CredentialType.OAUTH2.value:
|
||||
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
|
||||
if credential_type == CredentialType.API_KEY:
|
||||
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
def get_oauth_client_schema(self) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the oauth client schema of the provider
|
||||
|
||||
:return: the oauth client schema
|
||||
"""
|
||||
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
|
||||
|
||||
def get_supported_credential_types(self) -> list[CredentialType]:
|
||||
"""
|
||||
returns the credential support type of the provider
|
||||
"""
|
||||
types = []
|
||||
if self.entity.credentials_schema is not None and len(self.entity.credentials_schema) > 0:
|
||||
types.append(CredentialType.API_KEY)
|
||||
if self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) > 0:
|
||||
types.append(CredentialType.OAUTH2)
|
||||
return types
|
||||
|
||||
def get_tools(self) -> list[BuiltinTool]:
|
||||
"""
|
||||
returns a list of tools that the provider can provide
|
||||
|
||||
:return: list of tools
|
||||
"""
|
||||
return self._get_builtin_tools()
|
||||
|
||||
def get_tool(self, tool_name: str) -> BuiltinTool | None: # type: ignore
|
||||
"""
|
||||
returns the tool that the provider can provide
|
||||
"""
|
||||
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)
|
||||
|
||||
@property
|
||||
def need_credentials(self) -> bool:
|
||||
"""
|
||||
returns whether the provider needs credentials
|
||||
|
||||
:return: whether the provider needs credentials
|
||||
"""
|
||||
return (
|
||||
self.entity.credentials_schema is not None
|
||||
and len(self.entity.credentials_schema) != 0
|
||||
or (self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) != 0)
|
||||
)
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
@property
|
||||
def tool_labels(self) -> list[str]:
|
||||
"""
|
||||
returns the labels of the provider
|
||||
|
||||
:return: labels of the provider
|
||||
"""
|
||||
label_enums = self._get_tool_labels()
|
||||
return [default_tool_label_dict[label].name for label in label_enums]
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
"""
|
||||
returns the labels of the provider
|
||||
"""
|
||||
return self.entity.identity.tags or []
|
||||
|
||||
def validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
:param user_id: use id
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
# validate credentials format
|
||||
self.validate_credentials_format(credentials)
|
||||
|
||||
# validate credentials
|
||||
self._validate_credentials(user_id, credentials)
|
||||
|
||||
@abstractmethod
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
:param user_id: use id
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
pass
|
||||
20
dify/api/core/tools/builtin_tool/providers/_positions.py
Normal file
20
dify/api/core/tools/builtin_tool/providers/_positions.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import os.path
|
||||
|
||||
from core.helper.position_helper import get_tool_position_map, sort_by_position_map
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||
|
||||
|
||||
class BuiltinToolProviderSort:
|
||||
_position: dict[str, int] = {}
|
||||
|
||||
@classmethod
|
||||
def sort(cls, providers: list[ToolProviderApiEntity]) -> list[ToolProviderApiEntity]:
|
||||
if not cls._position:
|
||||
cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
def name_func(provider: ToolProviderApiEntity) -> str:
|
||||
return provider.name
|
||||
|
||||
sorted_providers = sort_by_position_map(cls._position, providers, name_func)
|
||||
|
||||
return sorted_providers
|
||||
@@ -0,0 +1,3 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="200" height="200" viewBox="0 0 200 200" fill="none">
|
||||
<path d="M167.358 102.395C167.358 117.174 157.246 129.18 144.61 131.027H137.861C125.225 129.18 115.113 117.174 115.113 102.395H100.792C100.792 123.637 115.118 142.106 133.653 145.801V164.276H147.139V145.801C165.674 142.106 180 124.558 180 102.4H167.358V102.395ZM154.717 62.677C154.717 53.4397 147.979 46.9765 140.396 46.9765C138.523 46.9446 136.663 47.3273 134.924 48.1024C133.185 48.8775 131.603 50.0294 130.27 51.4909C128.936 52.9524 127.878 54.6943 127.157 56.6148C126.436 58.5354 126.066 60.5962 126.07 62.677V78.3775H154.717V70.4478V62.677ZM126.07 102.395C126.07 111.632 132.813 118.095 140.396 118.095C142.269 118.127 144.13 117.744 145.868 116.969C147.607 116.194 149.189 115.042 150.523 113.581C151.856 112.119 152.914 110.377 153.635 108.457C154.356 106.536 154.726 104.475 154.722 102.395V86.694H126.07V102.395ZM92.1297 45.8938L70.4796 21.7595L69.4235 20.5865L59.604 20L68.3674 20.5865L67.3113 21.7654L64.1429 25.2961L63.6149 25.8826L64.1429 27.0614L66.2552 29.4133L77.8723 42.3631H54.1099C35.1 43.5361 20.3146 61.1896 20.3146 81.7874V83.5527H28.2354V81.7932C28.2354 65.8992 39.8525 52.3628 54.1099 51.1899H77.8723L66.2552 64.1338L64.671 65.8992L64.1429 67.0722L63.6149 67.6645L64.1429 68.251L68.3674 72.9606L68.8954 73.5471L69.4235 72.9606L74.1759 67.6645L92.1297 47.6591L92.6578 47.0727L92.1297 45.8938ZM20 95.8496V118.213H30.033V107.034H50.099V168.821H40.066V180H70.165V168.821H60.132V107.034H80.198V118.213H90.231V95.8496H20Z" fill="#FF0099"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.5 KiB |
@@ -0,0 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class AudioToolProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
pass
|
||||
11
dify/api/core/tools/builtin_tool/providers/audio/audio.yaml
Normal file
11
dify/api/core/tools/builtin_tool/providers/audio/audio.yaml
Normal file
@@ -0,0 +1,11 @@
|
||||
identity:
|
||||
author: hjlarry
|
||||
name: audio
|
||||
label:
|
||||
en_US: Audio
|
||||
description:
|
||||
en_US: A tool for tts and asr.
|
||||
zh_Hans: 一个用于文本转语音和语音转文本的工具。
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- utilities
|
||||
@@ -0,0 +1,84 @@
|
||||
import io
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.file.enums import FileType
|
||||
from core.file.file_manager import download
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
class ASRTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
file = tool_parameters.get("audio_file")
|
||||
if file.type != FileType.AUDIO: # type: ignore
|
||||
yield self.create_text_message("not a valid audio file")
|
||||
return
|
||||
audio_binary = io.BytesIO(download(file)) # type: ignore
|
||||
audio_binary.name = "temp.mp3"
|
||||
provider, model = tool_parameters.get("model").split("#") # type: ignore
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
provider=provider,
|
||||
model_type=ModelType.SPEECH2TEXT,
|
||||
model=model,
|
||||
)
|
||||
text = model_instance.invoke_speech2text(
|
||||
file=audio_binary,
|
||||
user=user_id,
|
||||
)
|
||||
yield self.create_text_message(text)
|
||||
|
||||
def get_available_models(self) -> list[tuple[str, str]]:
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_model_type(
|
||||
tenant_id=self.runtime.tenant_id, model_type="speech2text"
|
||||
)
|
||||
items = []
|
||||
for provider_model in models:
|
||||
provider = provider_model.provider
|
||||
for model in provider_model.models:
|
||||
items.append((provider, model.model))
|
||||
return items
|
||||
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> list[ToolParameter]:
|
||||
parameters = []
|
||||
|
||||
options = []
|
||||
for provider, model in self.get_available_models():
|
||||
option = PluginParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
|
||||
options.append(option)
|
||||
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name="model",
|
||||
label=I18nObject(en_US="Model", zh_Hans="Model"),
|
||||
human_description=I18nObject(
|
||||
en_US="All available ASR models. You can config model in the Model Provider of Settings.",
|
||||
zh_Hans="所有可用的 ASR 模型。你可以在设置中的模型供应商里配置。",
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
required=True,
|
||||
options=options,
|
||||
)
|
||||
)
|
||||
return parameters
|
||||
@@ -0,0 +1,22 @@
|
||||
identity:
|
||||
name: asr
|
||||
author: hjlarry
|
||||
label:
|
||||
en_US: Speech To Text
|
||||
description:
|
||||
human:
|
||||
en_US: Convert audio file to text.
|
||||
zh_Hans: 将音频文件转换为文本。
|
||||
llm: Convert audio file to text.
|
||||
parameters:
|
||||
- name: audio_file
|
||||
type: file
|
||||
required: true
|
||||
label:
|
||||
en_US: Audio File
|
||||
zh_Hans: 音频文件
|
||||
human_description:
|
||||
en_US: The audio file to be converted.
|
||||
zh_Hans: 要转换的音频文件。
|
||||
llm_description: The audio file to be converted.
|
||||
form: llm
|
||||
116
dify/api/core/tools/builtin_tool/providers/audio/tools/tts.py
Normal file
116
dify/api/core/tools/builtin_tool/providers/audio/tools/tts.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import io
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
class TTSTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
provider, model = tool_parameters.get("model").split("#") # type: ignore
|
||||
voice = tool_parameters.get(f"voice#{provider}#{model}")
|
||||
model_manager = ModelManager()
|
||||
if not self.runtime:
|
||||
raise ValueError("Runtime is required")
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
provider=provider,
|
||||
model_type=ModelType.TTS,
|
||||
model=model,
|
||||
)
|
||||
if not voice:
|
||||
voices = model_instance.get_tts_voices()
|
||||
if voices:
|
||||
voice = voices[0].get("value")
|
||||
if not voice:
|
||||
raise ValueError("Sorry, no voice available.")
|
||||
else:
|
||||
raise ValueError("Sorry, no voice available.")
|
||||
tts = model_instance.invoke_tts(
|
||||
content_text=tool_parameters.get("text"), # type: ignore
|
||||
user=user_id,
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
voice=voice,
|
||||
)
|
||||
buffer = io.BytesIO()
|
||||
for chunk in tts:
|
||||
buffer.write(chunk)
|
||||
|
||||
wav_bytes = buffer.getvalue()
|
||||
yield self.create_text_message("Audio generated successfully")
|
||||
yield self.create_blob_message(
|
||||
blob=wav_bytes,
|
||||
meta={"mime_type": "audio/x-wav"},
|
||||
)
|
||||
|
||||
def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
|
||||
if not self.runtime:
|
||||
raise ValueError("Runtime is required")
|
||||
model_provider_service = ModelProviderService()
|
||||
tid: str = self.runtime.tenant_id or ""
|
||||
models = model_provider_service.get_models_by_model_type(tenant_id=tid, model_type="tts")
|
||||
items = []
|
||||
for provider_model in models:
|
||||
provider = provider_model.provider
|
||||
for model in provider_model.models:
|
||||
voices = model.model_properties.get(ModelPropertyKey.VOICES, [])
|
||||
items.append((provider, model.model, voices))
|
||||
return items
|
||||
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> list[ToolParameter]:
|
||||
parameters = []
|
||||
|
||||
options = []
|
||||
for provider, model, voices in self.get_available_models():
|
||||
option = PluginParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
|
||||
options.append(option)
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name=f"voice#{provider}#{model}",
|
||||
label=I18nObject(en_US=f"Voice of {model}({provider})"),
|
||||
human_description=I18nObject(en_US=f"Select a voice for {model} model"),
|
||||
placeholder=I18nObject(en_US="Select a voice"),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
options=[
|
||||
PluginParameterOption(value=voice.get("mode"), label=I18nObject(en_US=voice.get("name")))
|
||||
for voice in voices
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
parameters.insert(
|
||||
0,
|
||||
ToolParameter(
|
||||
name="model",
|
||||
label=I18nObject(en_US="Model", zh_Hans="Model"),
|
||||
human_description=I18nObject(
|
||||
en_US="All available TTS models. You can config model in the Model Provider of Settings.",
|
||||
zh_Hans="所有可用的 TTS 模型。你可以在设置中的模型供应商里配置。",
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
required=True,
|
||||
placeholder=I18nObject(en_US="Select a model", zh_Hans="选择模型"),
|
||||
options=options,
|
||||
),
|
||||
)
|
||||
return parameters
|
||||
@@ -0,0 +1,22 @@
|
||||
identity:
|
||||
name: tts
|
||||
author: hjlarry
|
||||
label:
|
||||
en_US: Text To Speech
|
||||
description:
|
||||
human:
|
||||
en_US: Convert text to audio file.
|
||||
zh_Hans: 将文本转换为音频文件。
|
||||
llm: Convert text to audio file.
|
||||
parameters:
|
||||
- name: text
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Text
|
||||
zh_Hans: 文本
|
||||
human_description:
|
||||
en_US: The text to be converted.
|
||||
zh_Hans: 要转换的文本。
|
||||
llm_description: The text to be converted.
|
||||
form: llm
|
||||
@@ -0,0 +1 @@
|
||||
<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg" class="w-3.5 h-3.5" data-icon="Code" aria-hidden="true"><g id="icons/code"><path id="Vector (Stroke)" fill-rule="evenodd" clip-rule="evenodd" d="M8.32593 1.69675C8.67754 1.78466 8.89132 2.14096 8.80342 2.49257L6.47009 11.8259C6.38218 12.1775 6.02588 12.3913 5.67427 12.3034C5.32265 12.2155 5.10887 11.8592 5.19678 11.5076L7.53011 2.17424C7.61801 1.82263 7.97431 1.60885 8.32593 1.69675ZM3.96414 4.20273C4.22042 4.45901 4.22042 4.87453 3.96413 5.13081L2.45578 6.63914C2.45577 6.63915 2.45578 6.63914 2.45578 6.63914C2.25645 6.83851 2.25643 7.16168 2.45575 7.36103C2.45574 7.36103 2.45576 7.36104 2.45575 7.36103L3.96413 8.86936C4.22041 9.12564 4.22042 9.54115 3.96414 9.79744C3.70787 10.0537 3.29235 10.0537 3.03607 9.79745L1.52769 8.28913C0.815811 7.57721 0.815803 6.42302 1.52766 5.7111L3.03606 4.20272C3.29234 3.94644 3.70786 3.94644 3.96414 4.20273ZM10.0361 4.20273C10.2923 3.94644 10.7078 3.94644 10.9641 4.20272L12.4725 5.71108C13.1843 6.423 13.1844 7.57717 12.4725 8.28909L10.9641 9.79745C10.7078 10.0537 10.2923 10.0537 10.036 9.79744C9.77977 9.54115 9.77978 9.12564 10.0361 8.86936L11.5444 7.36107C11.7437 7.16172 11.7438 6.83854 11.5444 6.63917C11.5444 6.63915 11.5445 6.63918 11.5444 6.63917L10.0361 5.13081C9.77978 4.87453 9.77978 4.45901 10.0361 4.20273Z" fill="#2e90fa"></path></g></svg>
|
||||
|
After Width: | Height: | Size: 1.4 KiB |
8
dify/api/core/tools/builtin_tool/providers/code/code.py
Normal file
8
dify/api/core/tools/builtin_tool/providers/code/code.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class CodeToolProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
pass
|
||||
14
dify/api/core/tools/builtin_tool/providers/code/code.yaml
Normal file
14
dify/api/core/tools/builtin_tool/providers/code/code.yaml
Normal file
@@ -0,0 +1,14 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: code
|
||||
label:
|
||||
en_US: Code Interpreter
|
||||
zh_Hans: 代码解释器
|
||||
pt_BR: Interpretador de Código
|
||||
description:
|
||||
en_US: Run a piece of code and get the result back.
|
||||
zh_Hans: 运行一段代码并返回结果。
|
||||
pt_BR: Execute um trecho de código e obtenha o resultado de volta.
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- productivity
|
||||
@@ -0,0 +1,33 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
|
||||
class SimpleCode(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
invoke simple code
|
||||
"""
|
||||
|
||||
language = tool_parameters.get("language", CodeLanguage.PYTHON3)
|
||||
code = tool_parameters.get("code", "")
|
||||
|
||||
if language not in {CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT}:
|
||||
raise ValueError(f"Only python3 and javascript are supported, not {language}")
|
||||
|
||||
try:
|
||||
result = CodeExecutor.execute_code(language, "", code)
|
||||
yield self.create_text_message(result)
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(str(e))
|
||||
@@ -0,0 +1,51 @@
|
||||
identity:
|
||||
name: simple_code
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Code Interpreter
|
||||
zh_Hans: 代码解释器
|
||||
pt_BR: Interpretador de Código
|
||||
description:
|
||||
human:
|
||||
en_US: Run code and get the result back. When you're using a lower quality model, please make sure there are some tips help LLM to understand how to write the code.
|
||||
zh_Hans: 运行一段代码并返回结果。当您使用较低质量的模型时,请确保有一些提示帮助 LLM 理解如何编写代码。
|
||||
pt_BR: Execute um trecho de código e obtenha o resultado de volta. quando você estiver usando um modelo de qualidade inferior, certifique-se de que existam algumas dicas para ajudar o LLM a entender como escrever o código.
|
||||
llm: A tool for running code and getting the result back. Only native packages are allowed, network/IO operations are disabled. and you must use print() or console.log() to output the result or result will be empty.
|
||||
parameters:
|
||||
- name: language
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Language
|
||||
zh_Hans: 语言
|
||||
pt_BR: Idioma
|
||||
human_description:
|
||||
en_US: The programming language of the code
|
||||
zh_Hans: 代码的编程语言
|
||||
pt_BR: A linguagem de programação do código
|
||||
llm_description: language of the code, only "python3" and "javascript" are supported
|
||||
form: llm
|
||||
options:
|
||||
- value: python3
|
||||
label:
|
||||
en_US: Python3
|
||||
zh_Hans: Python3
|
||||
pt_BR: Python3
|
||||
- value: javascript
|
||||
label:
|
||||
en_US: JavaScript
|
||||
zh_Hans: JavaScript
|
||||
pt_BR: JavaScript
|
||||
- name: code
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Code
|
||||
zh_Hans: 代码
|
||||
pt_BR: Código
|
||||
human_description:
|
||||
en_US: The code to be executed
|
||||
zh_Hans: 要执行的代码
|
||||
pt_BR: O código a ser executado
|
||||
llm_description: code to be executed, only native packages are allowed, network/IO operations are disabled.
|
||||
form: llm
|
||||
@@ -0,0 +1,3 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M0.666992 8.00008C0.666992 3.94999 3.95024 0.666748 8.00033 0.666748C12.0504 0.666748 15.3337 3.94999 15.3337 8.00008C15.3337 12.0502 12.0504 15.3334 8.00033 15.3334C3.95024 15.3334 0.666992 12.0502 0.666992 8.00008ZM8.66699 4.00008C8.66699 3.63189 8.36852 3.33341 8.00033 3.33341C7.63213 3.33341 7.33366 3.63189 7.33366 4.00008V8.00008C7.33366 8.2526 7.47633 8.48344 7.70218 8.59637L10.3688 9.9297C10.6982 10.0944 11.0986 9.96088 11.2633 9.63156C11.4279 9.30224 11.2945 8.90179 10.9651 8.73713L8.66699 7.58806V4.00008Z" fill="#EC4A0A"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 691 B |
8
dify/api/core/tools/builtin_tool/providers/time/time.py
Normal file
8
dify/api/core/tools/builtin_tool/providers/time/time.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class WikiPediaProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
pass
|
||||
14
dify/api/core/tools/builtin_tool/providers/time/time.yaml
Normal file
14
dify/api/core/tools/builtin_tool/providers/time/time.yaml
Normal file
@@ -0,0 +1,14 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: time
|
||||
label:
|
||||
en_US: CurrentTime
|
||||
zh_Hans: 时间
|
||||
pt_BR: CurrentTime
|
||||
description:
|
||||
en_US: A tool for getting the current time.
|
||||
zh_Hans: 一个用于获取当前时间的工具。
|
||||
pt_BR: A tool for getting the current time.
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- utilities
|
||||
@@ -0,0 +1,35 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from pytz import timezone as pytz_timezone
|
||||
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
|
||||
class CurrentTimeTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
# get timezone
|
||||
tz = tool_parameters.get("timezone", "UTC")
|
||||
fm = tool_parameters.get("format") or "%Y-%m-%d %H:%M:%S %Z"
|
||||
if tz == "UTC":
|
||||
yield self.create_text_message(f"{datetime.now(UTC).strftime(fm)}")
|
||||
return
|
||||
|
||||
try:
|
||||
tz = pytz_timezone(tz)
|
||||
except Exception:
|
||||
yield self.create_text_message(f"Invalid timezone: {tz}")
|
||||
return
|
||||
yield self.create_text_message(f"{datetime.now(tz).strftime(fm)}")
|
||||
@@ -0,0 +1,131 @@
|
||||
identity:
|
||||
name: current_time
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Current Time
|
||||
zh_Hans: 获取当前时间
|
||||
pt_BR: Current Time
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for getting the current time.
|
||||
zh_Hans: 一个用于获取当前时间的工具。
|
||||
pt_BR: A tool for getting the current time.
|
||||
llm: A tool for getting the current time.
|
||||
parameters:
|
||||
- name: format
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Format
|
||||
zh_Hans: 格式
|
||||
pt_BR: Format
|
||||
human_description:
|
||||
en_US: Time format in strftime standard.
|
||||
zh_Hans: strftime 标准的时间格式。
|
||||
pt_BR: Time format in strftime standard.
|
||||
form: form
|
||||
default: "%Y-%m-%d %H:%M:%S"
|
||||
- name: timezone
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: Timezone
|
||||
zh_Hans: 时区
|
||||
pt_BR: Timezone
|
||||
human_description:
|
||||
en_US: Timezone
|
||||
zh_Hans: 时区
|
||||
pt_BR: Timezone
|
||||
form: form
|
||||
default: UTC
|
||||
options:
|
||||
- value: UTC
|
||||
label:
|
||||
en_US: UTC
|
||||
zh_Hans: UTC
|
||||
pt_BR: UTC
|
||||
- value: America/New_York
|
||||
label:
|
||||
en_US: America/New_York
|
||||
zh_Hans: 美洲/纽约
|
||||
pt_BR: America/New_York
|
||||
- value: America/Los_Angeles
|
||||
label:
|
||||
en_US: America/Los_Angeles
|
||||
zh_Hans: 美洲/洛杉矶
|
||||
pt_BR: America/Los_Angeles
|
||||
- value: America/Chicago
|
||||
label:
|
||||
en_US: America/Chicago
|
||||
zh_Hans: 美洲/芝加哥
|
||||
pt_BR: America/Chicago
|
||||
- value: America/Sao_Paulo
|
||||
label:
|
||||
en_US: America/Sao_Paulo
|
||||
zh_Hans: 美洲/圣保罗
|
||||
pt_BR: América/São Paulo
|
||||
- value: Asia/Shanghai
|
||||
label:
|
||||
en_US: Asia/Shanghai
|
||||
zh_Hans: 亚洲/上海
|
||||
pt_BR: Asia/Shanghai
|
||||
- value: Asia/Ho_Chi_Minh
|
||||
label:
|
||||
en_US: Asia/Ho_Chi_Minh
|
||||
zh_Hans: 亚洲/胡志明市
|
||||
pt_BR: Ásia/Ho Chi Minh
|
||||
- value: Asia/Tokyo
|
||||
label:
|
||||
en_US: Asia/Tokyo
|
||||
zh_Hans: 亚洲/东京
|
||||
pt_BR: Asia/Tokyo
|
||||
- value: Asia/Dubai
|
||||
label:
|
||||
en_US: Asia/Dubai
|
||||
zh_Hans: 亚洲/迪拜
|
||||
pt_BR: Asia/Dubai
|
||||
- value: Asia/Kolkata
|
||||
label:
|
||||
en_US: Asia/Kolkata
|
||||
zh_Hans: 亚洲/加尔各答
|
||||
pt_BR: Asia/Kolkata
|
||||
- value: Asia/Seoul
|
||||
label:
|
||||
en_US: Asia/Seoul
|
||||
zh_Hans: 亚洲/首尔
|
||||
pt_BR: Asia/Seoul
|
||||
- value: Asia/Singapore
|
||||
label:
|
||||
en_US: Asia/Singapore
|
||||
zh_Hans: 亚洲/新加坡
|
||||
pt_BR: Asia/Singapore
|
||||
- value: Europe/London
|
||||
label:
|
||||
en_US: Europe/London
|
||||
zh_Hans: 欧洲/伦敦
|
||||
pt_BR: Europe/London
|
||||
- value: Europe/Berlin
|
||||
label:
|
||||
en_US: Europe/Berlin
|
||||
zh_Hans: 欧洲/柏林
|
||||
pt_BR: Europe/Berlin
|
||||
- value: Europe/Moscow
|
||||
label:
|
||||
en_US: Europe/Moscow
|
||||
zh_Hans: 欧洲/莫斯科
|
||||
pt_BR: Europe/Moscow
|
||||
- value: Australia/Sydney
|
||||
label:
|
||||
en_US: Australia/Sydney
|
||||
zh_Hans: 澳大利亚/悉尼
|
||||
pt_BR: Australia/Sydney
|
||||
- value: Pacific/Auckland
|
||||
label:
|
||||
en_US: Pacific/Auckland
|
||||
zh_Hans: 太平洋/奥克兰
|
||||
pt_BR: Pacific/Auckland
|
||||
- value: Africa/Cairo
|
||||
label:
|
||||
en_US: Africa/Cairo
|
||||
zh_Hans: 非洲/开罗
|
||||
pt_BR: Africa/Cairo
|
||||
@@ -0,0 +1,50 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import pytz
|
||||
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
|
||||
class LocaltimeToTimestampTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
Convert localtime to timestamp
|
||||
"""
|
||||
localtime = tool_parameters.get("localtime")
|
||||
timezone = tool_parameters.get("timezone", "Asia/Shanghai")
|
||||
if not timezone:
|
||||
timezone = None
|
||||
time_format = "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
timestamp = self.localtime_to_timestamp(localtime, time_format, timezone) # type: ignore
|
||||
if not timestamp:
|
||||
yield self.create_text_message(f"Invalid localtime: {localtime}")
|
||||
return
|
||||
|
||||
yield self.create_text_message(f"{timestamp}")
|
||||
|
||||
# TODO: this method's type is messy
|
||||
@staticmethod
|
||||
def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
|
||||
try:
|
||||
local_time = datetime.strptime(localtime, time_format)
|
||||
if local_tz is None:
|
||||
localtime = local_time.astimezone() # type: ignore
|
||||
elif isinstance(local_tz, str):
|
||||
local_tz = pytz.timezone(local_tz)
|
||||
localtime = local_tz.localize(local_time) # type: ignore
|
||||
timestamp = int(localtime.timestamp()) # type: ignore
|
||||
return timestamp
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(str(e))
|
||||
@@ -0,0 +1,33 @@
|
||||
identity:
|
||||
name: localtime_to_timestamp
|
||||
author: zhuhao
|
||||
label:
|
||||
en_US: localtime to timestamp
|
||||
zh_Hans: 获取时间戳
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for localtime convert to timestamp
|
||||
zh_Hans: 获取时间戳
|
||||
llm: A tool for localtime convert to timestamp
|
||||
parameters:
|
||||
- name: localtime
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: localtime
|
||||
zh_Hans: 本地时间
|
||||
human_description:
|
||||
en_US: localtime, such as 2024-1-1 0:0:0
|
||||
zh_Hans: 本地时间,比如 2024-1-1 0:0:0
|
||||
- name: timezone
|
||||
type: string
|
||||
required: false
|
||||
form: llm
|
||||
label:
|
||||
en_US: Timezone
|
||||
zh_Hans: 时区
|
||||
human_description:
|
||||
en_US: Timezone, such as Asia/Shanghai
|
||||
zh_Hans: 时区,比如 Asia/Shanghai
|
||||
default: Asia/Shanghai
|
||||
@@ -0,0 +1,49 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import pytz
|
||||
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
|
||||
class TimestampToLocaltimeTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
Convert timestamp to localtime
|
||||
"""
|
||||
timestamp: int = tool_parameters.get("timestamp", 0)
|
||||
timezone = tool_parameters.get("timezone", "Asia/Shanghai")
|
||||
if not timezone:
|
||||
timezone = None
|
||||
time_format = "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
locatime = self.timestamp_to_localtime(timestamp, timezone)
|
||||
if not locatime:
|
||||
yield self.create_text_message(f"Invalid timestamp: {timestamp}")
|
||||
return
|
||||
|
||||
localtime_format = locatime.strftime(time_format)
|
||||
|
||||
yield self.create_text_message(f"{localtime_format}")
|
||||
|
||||
@staticmethod
|
||||
def timestamp_to_localtime(timestamp: int, local_tz=None) -> datetime | None:
|
||||
try:
|
||||
if local_tz is None:
|
||||
local_tz = datetime.now().astimezone().tzinfo
|
||||
if isinstance(local_tz, str):
|
||||
local_tz = pytz.timezone(local_tz)
|
||||
local_time = datetime.fromtimestamp(timestamp, local_tz)
|
||||
return local_time
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(str(e))
|
||||
@@ -0,0 +1,33 @@
|
||||
identity:
|
||||
name: timestamp_to_localtime
|
||||
author: zhuhao
|
||||
label:
|
||||
en_US: Timestamp to localtime
|
||||
zh_Hans: 时间戳转换
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for timestamp convert to localtime
|
||||
zh_Hans: 时间戳转换
|
||||
llm: A tool for timestamp convert to localtime
|
||||
parameters:
|
||||
- name: timestamp
|
||||
type: number
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Timestamp
|
||||
zh_Hans: 时间戳
|
||||
human_description:
|
||||
en_US: Timestamp
|
||||
zh_Hans: 时间戳
|
||||
- name: timezone
|
||||
type: string
|
||||
required: false
|
||||
form: llm
|
||||
label:
|
||||
en_US: Timezone
|
||||
zh_Hans: 时区
|
||||
human_description:
|
||||
en_US: Timezone, such as Asia/Shanghai
|
||||
zh_Hans: 时区,比如 Asia/Shanghai
|
||||
default: Asia/Shanghai
|
||||
@@ -0,0 +1,53 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import pytz
|
||||
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
|
||||
class TimezoneConversionTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
Convert time to equivalent time zone
|
||||
"""
|
||||
current_time = tool_parameters.get("current_time")
|
||||
current_timezone = tool_parameters.get("current_timezone", "Asia/Shanghai")
|
||||
target_timezone = tool_parameters.get("target_timezone", "Asia/Tokyo")
|
||||
target_time = self.timezone_convert(current_time, current_timezone, target_timezone) # type: ignore
|
||||
if not target_time:
|
||||
yield self.create_text_message(
|
||||
f"Invalid datetime and timezone: {current_time},{current_timezone},{target_timezone}"
|
||||
)
|
||||
return
|
||||
|
||||
yield self.create_text_message(f"{target_time}")
|
||||
|
||||
@staticmethod
|
||||
def timezone_convert(current_time: str, source_timezone: str, target_timezone: str) -> str:
|
||||
"""
|
||||
Convert a time string from source timezone to target timezone.
|
||||
"""
|
||||
time_format = "%Y-%m-%d %H:%M:%S"
|
||||
try:
|
||||
# get source timezone
|
||||
input_timezone = pytz.timezone(source_timezone)
|
||||
# get target timezone
|
||||
output_timezone = pytz.timezone(target_timezone)
|
||||
local_time = datetime.strptime(current_time, time_format)
|
||||
datetime_with_tz = input_timezone.localize(local_time)
|
||||
# timezone convert
|
||||
converted_datetime = datetime_with_tz.astimezone(output_timezone)
|
||||
return converted_datetime.strftime(time_format)
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(str(e))
|
||||
@@ -0,0 +1,44 @@
|
||||
identity:
|
||||
name: timezone_conversion
|
||||
author: zhuhao
|
||||
label:
|
||||
en_US: convert time to equivalent time zone
|
||||
zh_Hans: 时区转换
|
||||
description:
|
||||
human:
|
||||
en_US: A tool to convert time to equivalent time zone
|
||||
zh_Hans: 时区转换
|
||||
llm: A tool to convert time to equivalent time zone
|
||||
parameters:
|
||||
- name: current_time
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: current time
|
||||
zh_Hans: 当前时间
|
||||
human_description:
|
||||
en_US: current time, such as 2024-1-1 0:0:0
|
||||
zh_Hans: 当前时间,比如 2024-1-1 0:0:0
|
||||
- name: current_timezone
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Current Timezone
|
||||
zh_Hans: 当前时区
|
||||
human_description:
|
||||
en_US: Current Timezone, such as Asia/Shanghai
|
||||
zh_Hans: 当前时区,比如 Asia/Shanghai
|
||||
default: Asia/Shanghai
|
||||
- name: target_timezone
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Target Timezone
|
||||
zh_Hans: 目标时区
|
||||
human_description:
|
||||
en_US: Target Timezone, such as Asia/Tokyo
|
||||
zh_Hans: 目标时区,比如 Asia/Tokyo
|
||||
default: Asia/Tokyo
|
||||
@@ -0,0 +1,50 @@
|
||||
import calendar
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
|
||||
class WeekdayTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
Calculate the day of the week for a given date
|
||||
"""
|
||||
year = tool_parameters.get("year")
|
||||
month = tool_parameters.get("month")
|
||||
if month is None:
|
||||
raise ValueError("Month is required")
|
||||
day = tool_parameters.get("day")
|
||||
|
||||
date_obj = self.convert_datetime(year, month, day)
|
||||
if not date_obj:
|
||||
yield self.create_text_message(f"Invalid date: Year {year}, Month {month}, Day {day}.")
|
||||
return
|
||||
|
||||
weekday_name = calendar.day_name[date_obj.weekday()]
|
||||
month_name = calendar.month_name[month]
|
||||
readable_date = f"{month_name} {date_obj.day}, {date_obj.year}"
|
||||
yield self.create_text_message(f"{readable_date} is {weekday_name}.")
|
||||
|
||||
@staticmethod
|
||||
def convert_datetime(year, month, day) -> datetime | None:
|
||||
try:
|
||||
# allowed range in datetime module
|
||||
if not (year >= 1 and 1 <= month <= 12 and 1 <= day <= 31):
|
||||
return None
|
||||
|
||||
year = int(year)
|
||||
month = int(month)
|
||||
day = int(day)
|
||||
return datetime(year, month, day)
|
||||
except ValueError:
|
||||
return None
|
||||
@@ -0,0 +1,42 @@
|
||||
identity:
|
||||
name: weekday
|
||||
author: Bowen Liang
|
||||
label:
|
||||
en_US: Weekday Calculator
|
||||
zh_Hans: 星期几计算器
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for calculating the weekday of a given date.
|
||||
zh_Hans: 计算指定日期为星期几的工具。
|
||||
llm: A tool for calculating the weekday of a given date by year, month and day.
|
||||
parameters:
|
||||
- name: year
|
||||
type: number
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Year
|
||||
zh_Hans: 年
|
||||
human_description:
|
||||
en_US: Year
|
||||
zh_Hans: 年
|
||||
- name: month
|
||||
type: number
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Month
|
||||
zh_Hans: 月
|
||||
human_description:
|
||||
en_US: Month
|
||||
zh_Hans: 月
|
||||
- name: day
|
||||
type: number
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: day
|
||||
zh_Hans: 日
|
||||
human_description:
|
||||
en_US: day
|
||||
zh_Hans: 日
|
||||
@@ -0,0 +1,3 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="17" viewBox="0 0 16 17" fill="none">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M2.6665 1.16667C1.56193 1.16667 0.666504 2.0621 0.666504 3.16667C0.666504 4.27124 1.56193 5.16667 2.6665 5.16667C2.79161 5.16667 2.91403 5.15519 3.03277 5.13321C2.3808 6.09319 1.99984 7.25211 1.99984 8.5C1.99984 9.7479 2.3808 10.9068 3.03277 11.8668C2.91403 11.8448 2.79161 11.8333 2.6665 11.8333C1.56193 11.8333 0.666504 12.7288 0.666504 13.8333C0.666504 14.9379 1.56193 15.8333 2.6665 15.8333C3.77107 15.8333 4.6665 14.9379 4.6665 13.8333C4.6665 13.7082 4.65502 13.5858 4.63304 13.4671C5.59302 14.119 6.75194 14.5 7.99984 14.5C9.24773 14.5 10.4066 14.119 11.3666 13.4671C11.3447 13.5858 11.3332 13.7082 11.3332 13.8333C11.3332 14.9379 12.2286 15.8333 13.3332 15.8333C14.4377 15.8333 15.3332 14.9379 15.3332 13.8333C15.3332 12.7288 14.4377 11.8333 13.3332 11.8333C13.2081 11.8333 13.0856 11.8448 12.9669 11.8668C13.6189 10.9068 13.9998 9.7479 13.9998 8.5C13.9998 7.25211 13.6189 6.09319 12.9669 5.13321C13.0856 5.15519 13.2081 5.16667 13.3332 5.16667C14.4377 5.16667 15.3332 4.27124 15.3332 3.16667C15.3332 2.0621 14.4377 1.16667 13.3332 1.16667C12.2286 1.16667 11.3332 2.0621 11.3332 3.16667C11.3332 3.29177 11.3447 3.41419 11.3666 3.53293C10.4066 2.88097 9.24773 2.50001 7.99984 2.50001C6.75194 2.50001 5.59302 2.88097 4.63304 3.53293C4.65502 3.41419 4.6665 3.29177 4.6665 3.16667C4.6665 2.0621 3.77107 1.16667 2.6665 1.16667ZM3.38043 7.83334C3.63081 6.08287 4.85262 4.64578 6.48223 4.08565C5.79223 5.22099 5.36488 6.50185 5.23815 7.83334H3.38043ZM6.48228 12.9144C4.85264 12.3543 3.63082 10.9172 3.38043 9.16667H5.23815C5.3649 10.4982 5.79226 11.779 6.48228 12.9144ZM12.6192 9.16667C12.3689 10.9168 11.1475 12.3537 9.5183 12.9141C10.2082 11.7788 10.6355 10.498 10.7622 9.16667H12.6192ZM9.51834 4.08596C11.1475 4.64631 12.3689 6.0832 12.6192 7.83334H10.7622C10.6355 6.50197 10.2082 5.22123 9.51834 4.08596ZM9.4218 7.83334C9.27457 6.52262 8.78381 5.27411 8.00019 4.2145C7.21658 5.27411 6.72582 6.52262 6.57859 7.83334H9.4218ZM6.5786 9.16667C6.72583 10.4774 7.21659 11.7259 8.00019 12.7855C8.7838 11.7259 9.27456 10.4774 9.42179 9.16667H6.5786Z" fill="#DD2590"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 2.2 KiB |
@@ -0,0 +1,39 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.utils.web_reader_tool import get_url
|
||||
|
||||
|
||||
class WebscraperTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
try:
|
||||
url = tool_parameters.get("url", "")
|
||||
user_agent = tool_parameters.get("user_agent", "")
|
||||
if not url:
|
||||
yield self.create_text_message("Please input url")
|
||||
return
|
||||
|
||||
# get webpage
|
||||
result = get_url(url, user_agent=user_agent)
|
||||
|
||||
if tool_parameters.get("generate_summary"):
|
||||
# summarize and return
|
||||
yield self.create_text_message(self.summary(user_id=user_id, content=result))
|
||||
else:
|
||||
# return full webpage
|
||||
yield self.create_text_message(result)
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(str(e))
|
||||
@@ -0,0 +1,60 @@
|
||||
identity:
|
||||
name: webscraper
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Web Scraper
|
||||
zh_Hans: 网页爬虫
|
||||
pt_BR: Web Scraper
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for scraping webpages.
|
||||
zh_Hans: 一个用于爬取网页的工具。
|
||||
pt_BR: A tool for scraping webpages.
|
||||
llm: A tool for scraping webpages. Input should be a URL.
|
||||
parameters:
|
||||
- name: url
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: URL
|
||||
zh_Hans: 网页链接
|
||||
pt_BR: URL
|
||||
human_description:
|
||||
en_US: used for linking to webpages
|
||||
zh_Hans: 用于链接到网页
|
||||
pt_BR: used for linking to webpages
|
||||
llm_description: url for scraping
|
||||
form: llm
|
||||
- name: user_agent
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: User Agent
|
||||
zh_Hans: User Agent
|
||||
pt_BR: User Agent
|
||||
human_description:
|
||||
en_US: used for identifying the browser.
|
||||
zh_Hans: 用于识别浏览器。
|
||||
pt_BR: used for identifying the browser.
|
||||
form: form
|
||||
default: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/100.0.1000.0 Safari/537.36
|
||||
- name: generate_summary
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Whether to generate summary
|
||||
zh_Hans: 是否生成摘要
|
||||
human_description:
|
||||
en_US: If true, the crawler will only return the page summary content.
|
||||
zh_Hans: 如果启用,爬虫将仅返回页面摘要内容。
|
||||
form: form
|
||||
options:
|
||||
- value: "true"
|
||||
label:
|
||||
en_US: "Yes"
|
||||
zh_Hans: 是
|
||||
- value: "false"
|
||||
label:
|
||||
en_US: "No"
|
||||
zh_Hans: 否
|
||||
default: "false"
|
||||
@@ -0,0 +1,11 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class WebscraperProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
"""
|
||||
Validate credentials
|
||||
"""
|
||||
pass
|
||||
@@ -0,0 +1,15 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: webscraper
|
||||
label:
|
||||
en_US: WebScraper
|
||||
zh_Hans: 网页抓取
|
||||
pt_BR: WebScraper
|
||||
description:
|
||||
en_US: Web Scrapper tool kit is used to scrape web
|
||||
zh_Hans: 一个用于抓取网页的工具。
|
||||
pt_BR: Web Scrapper tool kit is used to scrape web
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- productivity
|
||||
credentials_for_provider: {}
|
||||
148
dify/api/core/tools/builtin_tool/tool.py
Normal file
148
dify/api/core/tools/builtin_tool/tool.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
|
||||
|
||||
_SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
|
||||
and you can quickly aimed at the main point of an webpage and reproduce it in your own words but
|
||||
retain the original meaning and keep the key points.
|
||||
however, the text you got is too long, what you got is possible a part of the text.
|
||||
Please summarize the text you got.
|
||||
"""
|
||||
|
||||
|
||||
class BuiltinTool(Tool):
|
||||
"""
|
||||
Builtin tool
|
||||
|
||||
:param meta: the meta data of a tool call processing
|
||||
"""
|
||||
|
||||
def __init__(self, provider: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.provider = provider
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "BuiltinTool":
|
||||
"""
|
||||
fork a new tool with metadata
|
||||
:return: the new tool
|
||||
"""
|
||||
return self.__class__(
|
||||
entity=self.entity.model_copy(),
|
||||
runtime=runtime,
|
||||
provider=self.provider,
|
||||
)
|
||||
|
||||
def invoke_model(self, user_id: str, prompt_messages: list[PromptMessage], stop: list[str]) -> LLMResult:
|
||||
"""
|
||||
invoke model
|
||||
|
||||
:param user_id: the user id
|
||||
:param prompt_messages: the prompt messages
|
||||
:param stop: the stop words
|
||||
:return: the model result
|
||||
"""
|
||||
# invoke model
|
||||
return ModelInvocationUtils.invoke(
|
||||
user_id=user_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
tool_type="builtin",
|
||||
tool_name=self.entity.identity.name,
|
||||
prompt_messages=prompt_messages,
|
||||
)
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
def get_max_tokens(self) -> int:
|
||||
"""
|
||||
get max tokens
|
||||
|
||||
:return: the max tokens
|
||||
"""
|
||||
if self.runtime is None:
|
||||
raise ValueError("runtime is required")
|
||||
|
||||
return ModelInvocationUtils.get_max_llm_context_tokens(
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
)
|
||||
|
||||
def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
|
||||
"""
|
||||
get prompt tokens
|
||||
|
||||
:param prompt_messages: the prompt messages
|
||||
:return: the tokens
|
||||
"""
|
||||
if self.runtime is None:
|
||||
raise ValueError("runtime is required")
|
||||
|
||||
return ModelInvocationUtils.calculate_tokens(
|
||||
tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
def summary(self, user_id: str, content: str) -> str:
|
||||
max_tokens = self.get_max_tokens()
|
||||
|
||||
if self.get_prompt_tokens(prompt_messages=[UserPromptMessage(content=content)]) < max_tokens * 0.6:
|
||||
return content
|
||||
|
||||
def get_prompt_tokens(content: str) -> int:
|
||||
return self.get_prompt_tokens(
|
||||
prompt_messages=[SystemPromptMessage(content=_SUMMARY_PROMPT), UserPromptMessage(content=content)]
|
||||
)
|
||||
|
||||
def summarize(content: str) -> str:
|
||||
summary = self.invoke_model(
|
||||
user_id=user_id,
|
||||
prompt_messages=[SystemPromptMessage(content=_SUMMARY_PROMPT), UserPromptMessage(content=content)],
|
||||
stop=[],
|
||||
)
|
||||
|
||||
assert isinstance(summary.message.content, str)
|
||||
return summary.message.content
|
||||
|
||||
lines = content.split("\n")
|
||||
new_lines = []
|
||||
# split long line into multiple lines
|
||||
for i in range(len(lines)):
|
||||
line = lines[i]
|
||||
if not line.strip():
|
||||
continue
|
||||
if len(line) < max_tokens * 0.5:
|
||||
new_lines.append(line)
|
||||
elif get_prompt_tokens(line) > max_tokens * 0.7:
|
||||
while get_prompt_tokens(line) > max_tokens * 0.7:
|
||||
new_lines.append(line[: int(max_tokens * 0.5)])
|
||||
line = line[int(max_tokens * 0.5) :]
|
||||
new_lines.append(line)
|
||||
else:
|
||||
new_lines.append(line)
|
||||
|
||||
# merge lines into messages with max tokens
|
||||
messages: list[str] = []
|
||||
for j in new_lines:
|
||||
if len(messages) == 0:
|
||||
messages.append(j)
|
||||
else:
|
||||
if len(messages[-1]) + len(j) < max_tokens * 0.5:
|
||||
messages[-1] += j
|
||||
if get_prompt_tokens(messages[-1] + j) > max_tokens * 0.7:
|
||||
messages.append(j)
|
||||
else:
|
||||
messages[-1] += j
|
||||
|
||||
summaries = []
|
||||
for i in range(len(messages)):
|
||||
message = messages[i]
|
||||
summary = summarize(message)
|
||||
summaries.append(summary)
|
||||
|
||||
result = "\n".join(summaries)
|
||||
|
||||
if self.get_prompt_tokens(prompt_messages=[UserPromptMessage(content=result)]) > max_tokens * 0.7:
|
||||
return self.summary(user_id=user_id, content=result)
|
||||
|
||||
return result
|
||||
209
dify/api/core/tools/custom_tool/provider.py
Normal file
209
dify/api/core/tools/custom_tool/provider.py
Normal file
@@ -0,0 +1,209 @@
|
||||
from pydantic import Field
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.custom_tool.tool import ApiTool
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolDescription,
|
||||
ToolEntity,
|
||||
ToolIdentity,
|
||||
ToolProviderEntity,
|
||||
ToolProviderIdentity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider
|
||||
|
||||
|
||||
class ApiToolProviderController(ToolProviderController):
|
||||
provider_id: str
|
||||
tenant_id: str
|
||||
tools: list[ApiTool] = Field(default_factory=list)
|
||||
|
||||
def __init__(self, entity: ToolProviderEntity, provider_id: str, tenant_id: str):
|
||||
super().__init__(entity)
|
||||
self.provider_id = provider_id
|
||||
self.tenant_id = tenant_id
|
||||
self.tools = []
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
|
||||
credentials_schema = [
|
||||
ProviderConfig(
|
||||
name="auth_type",
|
||||
required=True,
|
||||
type=ProviderConfig.Type.SELECT,
|
||||
options=[
|
||||
ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
|
||||
ProviderConfig.Option(value="api_key_header", label=I18nObject(en_US="Header", zh_Hans="请求头")),
|
||||
ProviderConfig.Option(
|
||||
value="api_key_query", label=I18nObject(en_US="Query Param", zh_Hans="查询参数")
|
||||
),
|
||||
],
|
||||
default="none",
|
||||
help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"),
|
||||
)
|
||||
]
|
||||
if auth_type == ApiProviderAuthType.API_KEY_HEADER:
|
||||
credentials_schema = [
|
||||
*credentials_schema,
|
||||
ProviderConfig(
|
||||
name="api_key_header",
|
||||
required=False,
|
||||
default="Authorization",
|
||||
type=ProviderConfig.Type.TEXT_INPUT,
|
||||
help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"),
|
||||
),
|
||||
ProviderConfig(
|
||||
name="api_key_value",
|
||||
required=True,
|
||||
type=ProviderConfig.Type.SECRET_INPUT,
|
||||
help=I18nObject(en_US="The api key", zh_Hans="api key 的值"),
|
||||
),
|
||||
ProviderConfig(
|
||||
name="api_key_header_prefix",
|
||||
required=False,
|
||||
default="basic",
|
||||
type=ProviderConfig.Type.SELECT,
|
||||
help=I18nObject(en_US="The prefix of the api key header", zh_Hans="api key header 的前缀"),
|
||||
options=[
|
||||
ProviderConfig.Option(value="basic", label=I18nObject(en_US="Basic", zh_Hans="Basic")),
|
||||
ProviderConfig.Option(value="bearer", label=I18nObject(en_US="Bearer", zh_Hans="Bearer")),
|
||||
ProviderConfig.Option(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")),
|
||||
],
|
||||
),
|
||||
]
|
||||
elif auth_type == ApiProviderAuthType.API_KEY_QUERY:
|
||||
credentials_schema = [
|
||||
*credentials_schema,
|
||||
ProviderConfig(
|
||||
name="api_key_query_param",
|
||||
required=False,
|
||||
default="key",
|
||||
type=ProviderConfig.Type.TEXT_INPUT,
|
||||
help=I18nObject(
|
||||
en_US="The query parameter name of the api key", zh_Hans="携带 api key 的查询参数名称"
|
||||
),
|
||||
),
|
||||
ProviderConfig(
|
||||
name="api_key_value",
|
||||
required=True,
|
||||
type=ProviderConfig.Type.SECRET_INPUT,
|
||||
help=I18nObject(en_US="The api key", zh_Hans="api key 的值"),
|
||||
),
|
||||
]
|
||||
elif auth_type == ApiProviderAuthType.NONE:
|
||||
pass
|
||||
|
||||
user = db_provider.user
|
||||
user_name = user.name if user else ""
|
||||
|
||||
return ApiToolProviderController(
|
||||
entity=ToolProviderEntity(
|
||||
identity=ToolProviderIdentity(
|
||||
author=user_name,
|
||||
name=db_provider.name,
|
||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
||||
description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
|
||||
icon=db_provider.icon,
|
||||
),
|
||||
credentials_schema=credentials_schema,
|
||||
plugin_id=None,
|
||||
),
|
||||
provider_id=db_provider.id or "",
|
||||
tenant_id=db_provider.tenant_id or "",
|
||||
)
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.API
|
||||
|
||||
def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool:
|
||||
"""
|
||||
parse tool bundle to tool
|
||||
|
||||
:param tool_bundle: the tool bundle
|
||||
:return: the tool
|
||||
"""
|
||||
return ApiTool(
|
||||
api_bundle=tool_bundle,
|
||||
provider_id=self.provider_id,
|
||||
entity=ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
author=tool_bundle.author,
|
||||
name=tool_bundle.operation_id or "default_tool",
|
||||
label=I18nObject(
|
||||
en_US=tool_bundle.operation_id or "default_tool",
|
||||
zh_Hans=tool_bundle.operation_id or "default_tool",
|
||||
),
|
||||
icon=self.entity.identity.icon,
|
||||
provider=self.provider_id,
|
||||
),
|
||||
description=ToolDescription(
|
||||
human=I18nObject(en_US=tool_bundle.summary or "", zh_Hans=tool_bundle.summary or ""),
|
||||
llm=tool_bundle.summary or "",
|
||||
),
|
||||
parameters=tool_bundle.parameters or [],
|
||||
),
|
||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||
)
|
||||
|
||||
def load_bundled_tools(self, tools: list[ApiToolBundle]):
|
||||
"""
|
||||
load bundled tools
|
||||
|
||||
:param tools: the bundled tools
|
||||
:return: the tools
|
||||
"""
|
||||
self.tools = [self._parse_tool_bundle(tool) for tool in tools]
|
||||
|
||||
return self.tools
|
||||
|
||||
def get_tools(self, tenant_id: str) -> list[ApiTool]:
|
||||
"""
|
||||
fetch tools from database
|
||||
|
||||
:param tenant_id: the tenant id
|
||||
:return: the tools
|
||||
"""
|
||||
if len(self.tools) > 0:
|
||||
return self.tools
|
||||
|
||||
tools: list[ApiTool] = []
|
||||
|
||||
# get tenant api providers
|
||||
db_providers = db.session.scalars(
|
||||
select(ApiToolProvider).where(
|
||||
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name
|
||||
)
|
||||
).all()
|
||||
|
||||
if db_providers and len(db_providers) != 0:
|
||||
for db_provider in db_providers:
|
||||
for tool in db_provider.tools:
|
||||
assistant_tool = self._parse_tool_bundle(tool)
|
||||
tools.append(assistant_tool)
|
||||
|
||||
self.tools = tools
|
||||
return tools
|
||||
|
||||
def get_tool(self, tool_name: str) -> ApiTool:
|
||||
"""
|
||||
get tool by name
|
||||
|
||||
:param tool_name: the name of the tool
|
||||
:return: the tool
|
||||
"""
|
||||
if self.tools is None:
|
||||
self.get_tools(self.tenant_id)
|
||||
|
||||
for tool in self.tools:
|
||||
if tool.entity.identity.name == tool_name:
|
||||
return tool
|
||||
|
||||
raise ValueError(f"tool {tool_name} not found")
|
||||
411
dify/api/core/tools/custom_tool/tool.py
Normal file
411
dify/api/core/tools/custom_tool/tool.py
Normal file
@@ -0,0 +1,411 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass
|
||||
from os import getenv
|
||||
from typing import Any, Union
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
|
||||
from core.file.file_manager import download
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||
|
||||
API_TOOL_DEFAULT_TIMEOUT = (
|
||||
int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")),
|
||||
int(getenv("API_TOOL_DEFAULT_READ_TIMEOUT", "60")),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedResponse:
|
||||
"""Represents a parsed HTTP response with type information"""
|
||||
|
||||
content: Union[str, dict]
|
||||
is_json: bool
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""Convert response to string format for credential validation"""
|
||||
if isinstance(self.content, dict):
|
||||
return json.dumps(self.content, ensure_ascii=False)
|
||||
return str(self.content)
|
||||
|
||||
|
||||
class ApiTool(Tool):
|
||||
"""
|
||||
Api tool
|
||||
"""
|
||||
|
||||
def __init__(self, entity: ToolEntity, api_bundle: ApiToolBundle, runtime: ToolRuntime, provider_id: str):
|
||||
super().__init__(entity, runtime)
|
||||
self.api_bundle = api_bundle
|
||||
self.provider_id = provider_id
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime):
|
||||
"""
|
||||
fork a new tool with metadata
|
||||
:return: the new tool
|
||||
"""
|
||||
if self.api_bundle is None:
|
||||
raise ValueError("api_bundle is required")
|
||||
return self.__class__(
|
||||
entity=self.entity,
|
||||
api_bundle=self.api_bundle.model_copy(),
|
||||
runtime=runtime,
|
||||
provider_id=self.provider_id,
|
||||
)
|
||||
|
||||
def validate_credentials(
|
||||
self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
validate the credentials for Api tool
|
||||
"""
|
||||
# assemble validate request and request parameters
|
||||
headers = self.assembling_request(parameters)
|
||||
|
||||
if format_only:
|
||||
return ""
|
||||
|
||||
response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters)
|
||||
# validate response
|
||||
parsed_response = self.validate_and_parse_response(response)
|
||||
# For credential validation, always return as string
|
||||
return parsed_response.to_string()
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.API
|
||||
|
||||
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
headers = {}
|
||||
if self.runtime is None:
|
||||
raise ToolProviderCredentialValidationError("runtime not initialized")
|
||||
|
||||
credentials = self.runtime.credentials or {}
|
||||
if "auth_type" not in credentials:
|
||||
raise ToolProviderCredentialValidationError("Missing auth_type")
|
||||
|
||||
if credentials["auth_type"] in ("api_key_header", "api_key"): # backward compatibility:
|
||||
api_key_header = "Authorization"
|
||||
|
||||
if "api_key_header" in credentials:
|
||||
api_key_header = credentials["api_key_header"]
|
||||
|
||||
if "api_key_value" not in credentials:
|
||||
raise ToolProviderCredentialValidationError("Missing api_key_value")
|
||||
elif not isinstance(credentials["api_key_value"], str):
|
||||
raise ToolProviderCredentialValidationError("api_key_value must be a string")
|
||||
|
||||
if "api_key_header_prefix" in credentials:
|
||||
api_key_header_prefix = credentials["api_key_header_prefix"]
|
||||
if api_key_header_prefix == "basic" and credentials["api_key_value"]:
|
||||
credentials["api_key_value"] = f"Basic {credentials['api_key_value']}"
|
||||
elif api_key_header_prefix == "bearer" and credentials["api_key_value"]:
|
||||
credentials["api_key_value"] = f"Bearer {credentials['api_key_value']}"
|
||||
elif api_key_header_prefix == "custom":
|
||||
pass
|
||||
|
||||
headers[api_key_header] = credentials["api_key_value"]
|
||||
|
||||
elif credentials["auth_type"] == "api_key_query":
|
||||
# For query parameter authentication, we don't add anything to headers
|
||||
# The query parameter will be added in do_http_request method
|
||||
pass
|
||||
|
||||
needed_parameters = [parameter for parameter in (self.api_bundle.parameters or []) if parameter.required]
|
||||
for parameter in needed_parameters:
|
||||
if parameter.required and parameter.name not in parameters:
|
||||
if parameter.default is not None:
|
||||
parameters[parameter.name] = parameter.default
|
||||
else:
|
||||
raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")
|
||||
|
||||
return headers
|
||||
|
||||
def validate_and_parse_response(self, response: httpx.Response) -> ParsedResponse:
|
||||
"""
|
||||
validate the response and return parsed content with type information
|
||||
|
||||
:return: ParsedResponse with content and is_json flag
|
||||
"""
|
||||
if isinstance(response, httpx.Response):
|
||||
if response.status_code >= 400:
|
||||
raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
|
||||
if not response.content:
|
||||
return ParsedResponse(
|
||||
"Empty response from the tool, please check your parameters and try again.", False
|
||||
)
|
||||
|
||||
# Check content type
|
||||
content_type = response.headers.get("content-type", "").lower()
|
||||
is_json_content_type = "application/json" in content_type
|
||||
|
||||
# Try to parse as JSON
|
||||
try:
|
||||
json_response = response.json()
|
||||
# If content-type indicates JSON, return as JSON object
|
||||
if is_json_content_type:
|
||||
return ParsedResponse(json_response, True)
|
||||
else:
|
||||
# If content-type doesn't indicate JSON, treat as text regardless of content
|
||||
return ParsedResponse(response.text, False)
|
||||
except Exception:
|
||||
# Not valid JSON, return as text
|
||||
return ParsedResponse(response.text, False)
|
||||
else:
|
||||
raise ValueError(f"Invalid response type {type(response)}")
|
||||
|
||||
@staticmethod
|
||||
def get_parameter_value(parameter, parameters):
|
||||
if parameter["name"] in parameters:
|
||||
return parameters[parameter["name"]]
|
||||
elif parameter.get("required", False):
|
||||
raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
|
||||
else:
|
||||
return (parameter.get("schema", {}) or {}).get("default", "")
|
||||
|
||||
def do_http_request(
|
||||
self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
do http request depending on api bundle
|
||||
"""
|
||||
method = method.lower()
|
||||
|
||||
params = {}
|
||||
path_params = {}
|
||||
# FIXME: body should be a dict[str, Any] but it changed a lot in this function
|
||||
body: Any = {}
|
||||
cookies = {}
|
||||
files = []
|
||||
|
||||
# Add API key to query parameters if auth_type is api_key_query
|
||||
if self.runtime and self.runtime.credentials:
|
||||
credentials = self.runtime.credentials
|
||||
if credentials.get("auth_type") == "api_key_query":
|
||||
api_key_query_param = credentials.get("api_key_query_param", "key")
|
||||
api_key_value = credentials.get("api_key_value")
|
||||
if api_key_value:
|
||||
params[api_key_query_param] = api_key_value
|
||||
|
||||
# check parameters
|
||||
for parameter in self.api_bundle.openapi.get("parameters", []):
|
||||
value = self.get_parameter_value(parameter, parameters)
|
||||
if parameter["in"] == "path":
|
||||
path_params[parameter["name"]] = value
|
||||
|
||||
elif parameter["in"] == "query":
|
||||
if value != "":
|
||||
params[parameter["name"]] = value
|
||||
|
||||
elif parameter["in"] == "cookie":
|
||||
cookies[parameter["name"]] = value
|
||||
|
||||
elif parameter["in"] == "header":
|
||||
headers[parameter["name"]] = str(value)
|
||||
|
||||
# check if there is a request body and handle it
|
||||
if "requestBody" in self.api_bundle.openapi and self.api_bundle.openapi["requestBody"] is not None:
|
||||
# handle json request body
|
||||
if "content" in self.api_bundle.openapi["requestBody"]:
|
||||
for content_type in self.api_bundle.openapi["requestBody"]["content"]:
|
||||
headers["Content-Type"] = content_type
|
||||
body_schema = self.api_bundle.openapi["requestBody"]["content"][content_type]["schema"]
|
||||
|
||||
# handle ref schema
|
||||
if "$ref" in body_schema:
|
||||
ref_path = body_schema["$ref"].split("/")
|
||||
ref_name = ref_path[-1]
|
||||
if (
|
||||
"components" in self.api_bundle.openapi
|
||||
and "schemas" in self.api_bundle.openapi["components"]
|
||||
):
|
||||
if ref_name in self.api_bundle.openapi["components"]["schemas"]:
|
||||
body_schema = self.api_bundle.openapi["components"]["schemas"][ref_name]
|
||||
|
||||
required = body_schema.get("required", [])
|
||||
properties = body_schema.get("properties", {})
|
||||
for name, property in properties.items():
|
||||
if name in parameters:
|
||||
# multiple file upload: if the type is array and the items have format as binary
|
||||
if property.get("type") == "array" and property.get("items", {}).get("format") == "binary":
|
||||
# parameters[name] should be a list of file objects.
|
||||
for f in parameters[name]:
|
||||
files.append((name, (f.filename, download(f), f.mime_type)))
|
||||
elif property.get("format") == "binary":
|
||||
f = parameters[name]
|
||||
files.append((name, (f.filename, download(f), f.mime_type)))
|
||||
elif "$ref" in property:
|
||||
body[name] = parameters[name]
|
||||
else:
|
||||
# convert type
|
||||
body[name] = self._convert_body_property_type(property, parameters[name])
|
||||
elif name in required:
|
||||
raise ToolParameterValidationError(
|
||||
f"Missing required parameter {name} in operation {self.api_bundle.operation_id}"
|
||||
)
|
||||
elif "default" in property:
|
||||
body[name] = property["default"]
|
||||
else:
|
||||
# omit optional parameters that weren't provided, instead of setting them to None
|
||||
pass
|
||||
break
|
||||
|
||||
# replace path parameters
|
||||
for name, value in path_params.items():
|
||||
url = url.replace(f"{{{name}}}", f"{value}")
|
||||
|
||||
# parse http body data if needed
|
||||
if "Content-Type" in headers:
|
||||
if headers["Content-Type"] == "application/json":
|
||||
body = json.dumps(body)
|
||||
elif headers["Content-Type"] == "application/x-www-form-urlencoded":
|
||||
body = urlencode(body)
|
||||
else:
|
||||
body = body
|
||||
|
||||
# if there is a file upload, remove the Content-Type header
|
||||
# so that httpx can automatically generate the boundary header required for multipart/form-data.
|
||||
# issue: https://github.com/langgenius/dify/issues/13684
|
||||
# reference: https://stackoverflow.com/questions/39280438/fetch-missing-boundary-in-multipart-form-data-post
|
||||
if files:
|
||||
headers.pop("Content-Type", None)
|
||||
|
||||
_METHOD_MAP = {
|
||||
"get": ssrf_proxy.get,
|
||||
"head": ssrf_proxy.head,
|
||||
"post": ssrf_proxy.post,
|
||||
"put": ssrf_proxy.put,
|
||||
"delete": ssrf_proxy.delete,
|
||||
"patch": ssrf_proxy.patch,
|
||||
}
|
||||
method_lc = method.lower()
|
||||
if method_lc not in _METHOD_MAP:
|
||||
raise ValueError(f"Invalid http method {method}")
|
||||
response: httpx.Response = _METHOD_MAP[
|
||||
method_lc
|
||||
]( # https://discuss.python.org/t/type-inference-for-function-return-types/42926
|
||||
url,
|
||||
max_retries=0,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
data=body,
|
||||
files=files,
|
||||
timeout=API_TOOL_DEFAULT_TIMEOUT,
|
||||
follow_redirects=True,
|
||||
)
|
||||
return response
|
||||
|
||||
def _convert_body_property_any_of(
|
||||
self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
|
||||
):
|
||||
if max_recursive <= 0:
|
||||
raise Exception("Max recursion depth reached")
|
||||
for option in any_of or []:
|
||||
try:
|
||||
if "type" in option:
|
||||
# Attempt to convert the value based on the type.
|
||||
if option["type"] == "integer" or option["type"] == "int":
|
||||
return int(value)
|
||||
elif option["type"] == "number":
|
||||
if "." in str(value):
|
||||
return float(value)
|
||||
else:
|
||||
return int(value)
|
||||
elif option["type"] == "string":
|
||||
return str(value)
|
||||
elif option["type"] == "boolean":
|
||||
if str(value).lower() in {"true", "1"}:
|
||||
return True
|
||||
elif str(value).lower() in {"false", "0"}:
|
||||
return False
|
||||
else:
|
||||
continue # Not a boolean, try next option
|
||||
elif option["type"] == "null" and not value:
|
||||
return None
|
||||
else:
|
||||
continue # Unsupported type, try next option
|
||||
elif "anyOf" in option and isinstance(option["anyOf"], list):
|
||||
# Recursive call to handle nested anyOf
|
||||
return self._convert_body_property_any_of(property, value, option["anyOf"], max_recursive - 1)
|
||||
except ValueError:
|
||||
continue # Conversion failed, try next option
|
||||
# If no option succeeded, you might want to return the value as is or raise an error
|
||||
return value # or raise ValueError(f"Cannot convert value '{value}' to any specified type in anyOf")
|
||||
|
||||
def _convert_body_property_type(self, property: dict[str, Any], value: Any):
|
||||
try:
|
||||
if "type" in property:
|
||||
if property["type"] == "integer" or property["type"] == "int":
|
||||
return int(value)
|
||||
elif property["type"] == "number":
|
||||
# check if it is a float
|
||||
if "." in str(value):
|
||||
return float(value)
|
||||
else:
|
||||
return int(value)
|
||||
elif property["type"] == "string":
|
||||
return str(value)
|
||||
elif property["type"] == "boolean":
|
||||
return bool(value)
|
||||
elif property["type"] == "null":
|
||||
if value is None:
|
||||
return None
|
||||
elif property["type"] == "object" or property["type"] == "array":
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return json.loads(value)
|
||||
except ValueError:
|
||||
return value
|
||||
elif isinstance(value, dict):
|
||||
return value
|
||||
else:
|
||||
return value
|
||||
else:
|
||||
raise ValueError(f"Invalid type {property['type']} for property {property}")
|
||||
elif "anyOf" in property and isinstance(property["anyOf"], list):
|
||||
return self._convert_body_property_any_of(property, value, property["anyOf"])
|
||||
except ValueError:
|
||||
return value
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
invoke http request
|
||||
"""
|
||||
response: httpx.Response | str = ""
|
||||
# assemble request
|
||||
headers = self.assembling_request(tool_parameters)
|
||||
|
||||
# do http request
|
||||
response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters)
|
||||
|
||||
# validate response
|
||||
parsed_response = self.validate_and_parse_response(response)
|
||||
|
||||
# assemble invoke message based on response type
|
||||
if parsed_response.is_json:
|
||||
if isinstance(parsed_response.content, dict):
|
||||
yield self.create_json_message(parsed_response.content)
|
||||
|
||||
# The yield below must be preserved to keep backward compatibility.
|
||||
#
|
||||
# ref: https://github.com/langgenius/dify/pull/23456#issuecomment-3182413088
|
||||
yield self.create_text_message(response.text)
|
||||
else:
|
||||
# Convert to string if needed and create text message
|
||||
text_response = (
|
||||
parsed_response.content if isinstance(parsed_response.content, str) else str(parsed_response.content)
|
||||
)
|
||||
yield self.create_text_message(text_response)
|
||||
132
dify/api/core/tools/entities/api_entities.py
Normal file
132
dify/api/core/tools/entities/api_entities.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.__base.tool import ToolParameter
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
|
||||
|
||||
class ToolApiEntity(BaseModel):
|
||||
author: str
|
||||
name: str # identifier
|
||||
label: I18nObject # label
|
||||
description: I18nObject
|
||||
parameters: list[ToolParameter] | None = None
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
output_schema: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
||||
|
||||
ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow", "mcp"] | None
|
||||
|
||||
|
||||
class ToolProviderApiEntity(BaseModel):
|
||||
id: str
|
||||
author: str
|
||||
name: str # identifier
|
||||
description: I18nObject
|
||||
icon: str | Mapping[str, str]
|
||||
icon_dark: str | Mapping[str, str] = ""
|
||||
label: I18nObject # label
|
||||
type: ToolProviderType
|
||||
masked_credentials: Mapping[str, object] = Field(default_factory=dict)
|
||||
original_credentials: Mapping[str, object] = Field(default_factory=dict)
|
||||
is_team_authorization: bool = False
|
||||
allow_delete: bool = True
|
||||
plugin_id: str | None = Field(default="", description="The plugin id of the tool")
|
||||
plugin_unique_identifier: str | None = Field(default="", description="The unique identifier of the tool")
|
||||
tools: list[ToolApiEntity] = Field(default_factory=list[ToolApiEntity])
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
# MCP
|
||||
server_url: str | None = Field(default="", description="The server url of the tool")
|
||||
updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
|
||||
server_identifier: str | None = Field(default="", description="The server identifier of the MCP tool")
|
||||
|
||||
masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool")
|
||||
original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool")
|
||||
authentication: MCPAuthentication | None = Field(default=None, description="The OAuth config of the MCP tool")
|
||||
is_dynamic_registration: bool = Field(default=True, description="Whether the MCP tool is dynamically registered")
|
||||
configuration: MCPConfiguration | None = Field(
|
||||
default=None, description="The timeout and sse_read_timeout of the MCP tool"
|
||||
)
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
def convert_none_to_empty_list(cls, v):
|
||||
return v if v is not None else []
|
||||
|
||||
def to_dict(self):
|
||||
# -------------
|
||||
# overwrite tool parameter types for temp fix
|
||||
tools = jsonable_encoder(self.tools)
|
||||
for tool in tools:
|
||||
if tool.get("parameters"):
|
||||
for parameter in tool.get("parameters"):
|
||||
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES:
|
||||
parameter["type"] = "files"
|
||||
if parameter.get("input_schema") is None:
|
||||
parameter.pop("input_schema", None)
|
||||
# -------------
|
||||
optional_fields = self.optional_field("server_url", self.server_url)
|
||||
if self.type == ToolProviderType.MCP:
|
||||
optional_fields.update(self.optional_field("updated_at", self.updated_at))
|
||||
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
|
||||
optional_fields.update(
|
||||
self.optional_field(
|
||||
"configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration()
|
||||
)
|
||||
)
|
||||
optional_fields.update(
|
||||
self.optional_field("authentication", self.authentication.model_dump() if self.authentication else None)
|
||||
)
|
||||
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
|
||||
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
|
||||
optional_fields.update(self.optional_field("original_headers", self.original_headers))
|
||||
return {
|
||||
"id": self.id,
|
||||
"author": self.author,
|
||||
"name": self.name,
|
||||
"plugin_id": self.plugin_id,
|
||||
"plugin_unique_identifier": self.plugin_unique_identifier,
|
||||
"description": self.description.to_dict(),
|
||||
"icon": self.icon,
|
||||
"icon_dark": self.icon_dark,
|
||||
"label": self.label.to_dict(),
|
||||
"type": self.type.value,
|
||||
"team_credentials": self.masked_credentials,
|
||||
"is_team_authorization": self.is_team_authorization,
|
||||
"allow_delete": self.allow_delete,
|
||||
"tools": tools,
|
||||
"labels": self.labels,
|
||||
**optional_fields,
|
||||
}
|
||||
|
||||
def optional_field(self, key: str, value: Any):
|
||||
"""Return dict with key-value if value is truthy, empty dict otherwise."""
|
||||
return {key: value} if value else {}
|
||||
|
||||
|
||||
class ToolProviderCredentialApiEntity(BaseModel):
|
||||
id: str = Field(description="The unique id of the credential")
|
||||
name: str = Field(description="The name of the credential")
|
||||
provider: str = Field(description="The provider of the credential")
|
||||
credential_type: CredentialType = Field(description="The type of the credential")
|
||||
is_default: bool = Field(
|
||||
default=False, description="Whether the credential is the default credential for the provider in the workspace"
|
||||
)
|
||||
credentials: Mapping[str, object] = Field(description="The credentials of the provider", default_factory=dict)
|
||||
|
||||
|
||||
class ToolProviderCredentialInfoApiEntity(BaseModel):
|
||||
supported_credential_types: list[CredentialType] = Field(
|
||||
description="The supported credential types of the provider"
|
||||
)
|
||||
is_oauth_custom_client_enabled: bool = Field(
|
||||
default=False, description="Whether the OAuth custom client is enabled for the provider"
|
||||
)
|
||||
credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider")
|
||||
22
dify/api/core/tools/entities/common_entities.py
Normal file
22
dify/api/core/tools/entities/common_entities.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class I18nObject(BaseModel):
|
||||
"""
|
||||
Model class for i18n object.
|
||||
"""
|
||||
|
||||
en_US: str
|
||||
zh_Hans: str | None = Field(default=None)
|
||||
pt_BR: str | None = Field(default=None)
|
||||
ja_JP: str | None = Field(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _populate_missing_locales(self):
|
||||
self.zh_Hans = self.zh_Hans or self.en_US
|
||||
self.pt_BR = self.pt_BR or self.en_US
|
||||
self.ja_JP = self.ja_JP or self.en_US
|
||||
return self
|
||||
|
||||
def to_dict(self):
|
||||
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
|
||||
1
dify/api/core/tools/entities/constants.py
Normal file
1
dify/api/core/tools/entities/constants.py
Normal file
@@ -0,0 +1 @@
|
||||
TOOL_SELECTOR_MODEL_IDENTITY = "__dify__tool_selector__"
|
||||
27
dify/api/core/tools/entities/tool_bundle.py
Normal file
27
dify/api/core/tools/entities/tool_bundle.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.tools.entities.tool_entities import ToolParameter
|
||||
|
||||
|
||||
class ApiToolBundle(BaseModel):
|
||||
"""
|
||||
This class is used to store the schema information of an api based tool.
|
||||
such as the url, the method, the parameters, etc.
|
||||
"""
|
||||
|
||||
# server_url
|
||||
server_url: str
|
||||
# method
|
||||
method: str
|
||||
# summary
|
||||
summary: str | None = None
|
||||
# operation_id
|
||||
operation_id: str | None = None
|
||||
# parameters
|
||||
parameters: list[ToolParameter] | None = None
|
||||
# author
|
||||
author: str
|
||||
# icon
|
||||
icon: str | None = None
|
||||
# openapi operation
|
||||
openapi: dict
|
||||
492
dify/api/core/tools/entities/tool_entities.py
Normal file
492
dify/api/core/tools/entities/tool_entities.py
Normal file
@@ -0,0 +1,492 @@
|
||||
import base64
|
||||
import contextlib
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.parameters import (
|
||||
MCPServerParameterType,
|
||||
PluginParameter,
|
||||
PluginParameterOption,
|
||||
PluginParameterType,
|
||||
as_normal_type,
|
||||
cast_parameter_value,
|
||||
init_frontend_parameter,
|
||||
)
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
|
||||
|
||||
|
||||
class ToolLabelEnum(StrEnum):
|
||||
SEARCH = "search"
|
||||
IMAGE = "image"
|
||||
VIDEOS = "videos"
|
||||
WEATHER = "weather"
|
||||
FINANCE = "finance"
|
||||
DESIGN = "design"
|
||||
TRAVEL = "travel"
|
||||
SOCIAL = "social"
|
||||
NEWS = "news"
|
||||
MEDICAL = "medical"
|
||||
PRODUCTIVITY = "productivity"
|
||||
EDUCATION = "education"
|
||||
BUSINESS = "business"
|
||||
ENTERTAINMENT = "entertainment"
|
||||
UTILITIES = "utilities"
|
||||
RAG = "rag"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class ToolProviderType(StrEnum):
|
||||
"""
|
||||
Enum class for tool provider
|
||||
"""
|
||||
|
||||
PLUGIN = auto()
|
||||
BUILT_IN = "builtin"
|
||||
WORKFLOW = auto()
|
||||
API = auto()
|
||||
APP = auto()
|
||||
DATASET_RETRIEVAL = "dataset-retrieval"
|
||||
MCP = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ToolProviderType":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class ApiProviderSchemaType(StrEnum):
|
||||
"""
|
||||
Enum class for api provider schema type.
|
||||
"""
|
||||
|
||||
OPENAPI = auto()
|
||||
SWAGGER = auto()
|
||||
OPENAI_PLUGIN = auto()
|
||||
OPENAI_ACTIONS = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ApiProviderSchemaType":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class ApiProviderAuthType(StrEnum):
|
||||
"""
|
||||
Enum class for api provider auth type.
|
||||
"""
|
||||
|
||||
NONE = auto()
|
||||
API_KEY_HEADER = auto()
|
||||
API_KEY_QUERY = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ApiProviderAuthType":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
# 'api_key' deprecated in PR #21656
|
||||
# normalize & tiny alias for backward compatibility
|
||||
v = (value or "").strip().lower()
|
||||
if v == "api_key":
|
||||
v = cls.API_KEY_HEADER
|
||||
|
||||
for mode in cls:
|
||||
if mode.value == v:
|
||||
return mode
|
||||
|
||||
valid = ", ".join(m.value for m in cls)
|
||||
raise ValueError(f"invalid mode value '{value}', expected one of: {valid}")
|
||||
|
||||
|
||||
class ToolInvokeMessage(BaseModel):
|
||||
class TextMessage(BaseModel):
|
||||
text: str
|
||||
|
||||
class JsonMessage(BaseModel):
|
||||
json_object: dict
|
||||
suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string")
|
||||
|
||||
class BlobMessage(BaseModel):
|
||||
blob: bytes
|
||||
|
||||
class BlobChunkMessage(BaseModel):
|
||||
id: str = Field(..., description="The id of the blob")
|
||||
sequence: int = Field(..., description="The sequence of the chunk")
|
||||
total_length: int = Field(..., description="The total length of the blob")
|
||||
blob: bytes = Field(..., description="The blob data of the chunk")
|
||||
end: bool = Field(..., description="Whether the chunk is the last chunk")
|
||||
|
||||
class FileMessage(BaseModel):
|
||||
pass
|
||||
|
||||
class VariableMessage(BaseModel):
|
||||
variable_name: str = Field(..., description="The name of the variable")
|
||||
variable_value: Any = Field(..., description="The value of the variable")
|
||||
stream: bool = Field(default=False, description="Whether the variable is streamed")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def transform_variable_value(cls, values):
|
||||
"""
|
||||
Only basic types and lists are allowed.
|
||||
"""
|
||||
value = values.get("variable_value")
|
||||
if not isinstance(value, dict | list | str | int | float | bool):
|
||||
raise ValueError("Only basic types and lists are allowed.")
|
||||
|
||||
# if stream is true, the value must be a string
|
||||
if values.get("stream"):
|
||||
if not isinstance(value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
|
||||
return values
|
||||
|
||||
@field_validator("variable_name", mode="before")
|
||||
@classmethod
|
||||
def transform_variable_name(cls, value: str) -> str:
|
||||
"""
|
||||
The variable name must be a string.
|
||||
"""
|
||||
if value in {"json", "text", "files"}:
|
||||
raise ValueError(f"The variable name '{value}' is reserved.")
|
||||
return value
|
||||
|
||||
class LogMessage(BaseModel):
|
||||
class LogStatus(StrEnum):
|
||||
START = auto()
|
||||
ERROR = auto()
|
||||
SUCCESS = auto()
|
||||
|
||||
id: str
|
||||
label: str = Field(..., description="The label of the log")
|
||||
parent_id: str | None = Field(default=None, description="Leave empty for root log")
|
||||
error: str | None = Field(default=None, description="The error message")
|
||||
status: LogStatus = Field(..., description="The status of the log")
|
||||
data: Mapping[str, Any] = Field(..., description="Detailed log data")
|
||||
metadata: Mapping[str, Any] = Field(default_factory=dict, description="The metadata of the log")
|
||||
|
||||
@field_validator("metadata", mode="before")
|
||||
@classmethod
|
||||
def _normalize_metadata(cls, value: Mapping[str, Any] | None) -> Mapping[str, Any]:
|
||||
return value or {}
|
||||
|
||||
class RetrieverResourceMessage(BaseModel):
|
||||
retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
|
||||
class MessageType(StrEnum):
|
||||
TEXT = auto()
|
||||
IMAGE = auto()
|
||||
LINK = auto()
|
||||
BLOB = auto()
|
||||
JSON = auto()
|
||||
IMAGE_LINK = auto()
|
||||
BINARY_LINK = auto()
|
||||
VARIABLE = auto()
|
||||
FILE = auto()
|
||||
LOG = auto()
|
||||
BLOB_CHUNK = auto()
|
||||
RETRIEVER_RESOURCES = auto()
|
||||
|
||||
type: MessageType = MessageType.TEXT
|
||||
"""
|
||||
plain text, image url or link url
|
||||
"""
|
||||
message: (
|
||||
JsonMessage
|
||||
| TextMessage
|
||||
| BlobChunkMessage
|
||||
| BlobMessage
|
||||
| LogMessage
|
||||
| FileMessage
|
||||
| None
|
||||
| VariableMessage
|
||||
| RetrieverResourceMessage
|
||||
)
|
||||
meta: dict[str, Any] | None = None
|
||||
|
||||
@field_validator("message", mode="before")
|
||||
@classmethod
|
||||
def decode_blob_message(cls, v):
|
||||
if isinstance(v, dict) and "blob" in v:
|
||||
with contextlib.suppress(Exception):
|
||||
v["blob"] = base64.b64decode(v["blob"])
|
||||
return v
|
||||
|
||||
@field_serializer("message")
|
||||
def serialize_message(self, v):
|
||||
if isinstance(v, self.BlobMessage):
|
||||
return {"blob": base64.b64encode(v.blob).decode("utf-8")}
|
||||
return v
|
||||
|
||||
|
||||
class ToolInvokeMessageBinary(BaseModel):
|
||||
mimetype: str = Field(..., description="The mimetype of the binary")
|
||||
url: str = Field(..., description="The url of the binary")
|
||||
file_var: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ToolParameter(PluginParameter):
|
||||
"""
|
||||
Overrides type
|
||||
"""
|
||||
|
||||
class ToolParameterType(StrEnum):
|
||||
"""
|
||||
removes TOOLS_SELECTOR from PluginParameterType
|
||||
"""
|
||||
|
||||
STRING = PluginParameterType.STRING
|
||||
NUMBER = PluginParameterType.NUMBER
|
||||
BOOLEAN = PluginParameterType.BOOLEAN
|
||||
SELECT = PluginParameterType.SELECT
|
||||
SECRET_INPUT = PluginParameterType.SECRET_INPUT
|
||||
FILE = PluginParameterType.FILE
|
||||
FILES = PluginParameterType.FILES
|
||||
CHECKBOX = PluginParameterType.CHECKBOX
|
||||
APP_SELECTOR = PluginParameterType.APP_SELECTOR
|
||||
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR
|
||||
ANY = PluginParameterType.ANY
|
||||
DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT
|
||||
|
||||
# MCP object and array type parameters
|
||||
ARRAY = MCPServerParameterType.ARRAY
|
||||
OBJECT = MCPServerParameterType.OBJECT
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES
|
||||
|
||||
def as_normal_type(self):
|
||||
return as_normal_type(self)
|
||||
|
||||
def cast_value(self, value: Any):
|
||||
return cast_parameter_value(self, value)
|
||||
|
||||
class ToolParameterForm(StrEnum):
|
||||
SCHEMA = auto() # should be set while adding tool
|
||||
FORM = auto() # should be set before invoking tool
|
||||
LLM = auto() # will be set by LLM
|
||||
|
||||
type: ToolParameterType = Field(..., description="The type of the parameter")
|
||||
human_description: I18nObject | None = Field(default=None, description="The description presented to the user")
|
||||
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
|
||||
llm_description: str | None = None
|
||||
# MCP object and array type parameters use this field to store the schema
|
||||
input_schema: dict | None = None
|
||||
|
||||
@classmethod
|
||||
def get_simple_instance(
|
||||
cls,
|
||||
name: str,
|
||||
llm_description: str,
|
||||
typ: ToolParameterType,
|
||||
required: bool,
|
||||
options: list[str] | None = None,
|
||||
) -> "ToolParameter":
|
||||
"""
|
||||
get a simple tool parameter
|
||||
|
||||
:param name: the name of the parameter
|
||||
:param llm_description: the description presented to the LLM
|
||||
:param typ: the type of the parameter
|
||||
:param required: if the parameter is required
|
||||
:param options: the options of the parameter
|
||||
"""
|
||||
# convert options to ToolParameterOption
|
||||
if options:
|
||||
option_objs = [
|
||||
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
||||
for option in options
|
||||
]
|
||||
else:
|
||||
option_objs = []
|
||||
|
||||
return cls(
|
||||
name=name,
|
||||
label=I18nObject(en_US="", zh_Hans=""),
|
||||
placeholder=None,
|
||||
human_description=I18nObject(en_US="", zh_Hans=""),
|
||||
type=typ,
|
||||
form=cls.ToolParameterForm.LLM,
|
||||
llm_description=llm_description,
|
||||
required=required,
|
||||
options=option_objs,
|
||||
)
|
||||
|
||||
def init_frontend_parameter(self, value: Any):
|
||||
return init_frontend_parameter(self, self.type, value)
|
||||
|
||||
|
||||
class ToolProviderIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the tool")
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
description: I18nObject = Field(..., description="The description of the tool")
|
||||
icon: str = Field(..., description="The icon of the tool")
|
||||
icon_dark: str | None = Field(default=None, description="The dark icon of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
tags: list[ToolLabelEnum] | None = Field(
|
||||
default=[],
|
||||
description="The tags of the tool",
|
||||
)
|
||||
|
||||
|
||||
class ToolIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the tool")
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
provider: str = Field(..., description="The provider of the tool")
|
||||
icon: str | None = None
|
||||
|
||||
|
||||
class ToolDescription(BaseModel):
|
||||
human: I18nObject = Field(..., description="The description presented to the user")
|
||||
llm: str = Field(..., description="The description presented to the LLM")
|
||||
|
||||
|
||||
class ToolEntity(BaseModel):
|
||||
identity: ToolIdentity
|
||||
parameters: list[ToolParameter] = Field(default_factory=list[ToolParameter])
|
||||
description: ToolDescription | None = None
|
||||
output_schema: Mapping[str, object] = Field(default_factory=dict)
|
||||
has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@field_validator("parameters", mode="before")
|
||||
@classmethod
|
||||
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
|
||||
return v or []
|
||||
|
||||
@field_validator("output_schema", mode="before")
|
||||
@classmethod
|
||||
def _normalize_output_schema(cls, value: Mapping[str, object] | None) -> Mapping[str, object]:
|
||||
return value or {}
|
||||
|
||||
|
||||
class OAuthSchema(BaseModel):
|
||||
client_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list[ProviderConfig], description="The schema of the OAuth client"
|
||||
)
|
||||
credentials_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list[ProviderConfig], description="The schema of the OAuth credentials"
|
||||
)
|
||||
|
||||
|
||||
class ToolProviderEntity(BaseModel):
|
||||
identity: ToolProviderIdentity
|
||||
plugin_id: str | None = None
|
||||
credentials_schema: list[ProviderConfig] = Field(default_factory=list[ProviderConfig])
|
||||
oauth_schema: OAuthSchema | None = None
|
||||
|
||||
|
||||
class ToolProviderEntityWithPlugin(ToolProviderEntity):
|
||||
tools: list[ToolEntity] = Field(default_factory=list[ToolEntity])
|
||||
|
||||
|
||||
class WorkflowToolParameterConfiguration(BaseModel):
|
||||
"""
|
||||
Workflow tool configuration
|
||||
"""
|
||||
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
description: str = Field(..., description="The description of the parameter")
|
||||
form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter")
|
||||
|
||||
|
||||
class ToolInvokeMeta(BaseModel):
|
||||
"""
|
||||
Tool invoke meta
|
||||
"""
|
||||
|
||||
time_cost: float = Field(..., description="The time cost of the tool invoke")
|
||||
error: str | None = None
|
||||
tool_config: dict | None = None
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> "ToolInvokeMeta":
|
||||
"""
|
||||
Get an empty instance of ToolInvokeMeta
|
||||
"""
|
||||
return cls(time_cost=0.0, error=None, tool_config={})
|
||||
|
||||
@classmethod
|
||||
def error_instance(cls, error: str) -> "ToolInvokeMeta":
|
||||
"""
|
||||
Get an instance of ToolInvokeMeta with error
|
||||
"""
|
||||
return cls(time_cost=0.0, error=error, tool_config={})
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"time_cost": self.time_cost,
|
||||
"error": self.error,
|
||||
"tool_config": self.tool_config,
|
||||
}
|
||||
|
||||
|
||||
class ToolLabel(BaseModel):
|
||||
"""
|
||||
Tool label
|
||||
"""
|
||||
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
icon: str = Field(..., description="The icon of the tool")
|
||||
|
||||
|
||||
class ToolInvokeFrom(StrEnum):
|
||||
"""
|
||||
Enum class for tool invoke
|
||||
"""
|
||||
|
||||
WORKFLOW = auto()
|
||||
AGENT = auto()
|
||||
PLUGIN = auto()
|
||||
|
||||
|
||||
class ToolSelector(BaseModel):
|
||||
dify_model_identity: str = TOOL_SELECTOR_MODEL_IDENTITY
|
||||
|
||||
class Parameter(BaseModel):
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
type: ToolParameter.ToolParameterType = Field(..., description="The type of the parameter")
|
||||
required: bool = Field(..., description="Whether the parameter is required")
|
||||
description: str = Field(..., description="The description of the parameter")
|
||||
default: Union[int, float, str] | None = None
|
||||
options: list[PluginParameterOption] | None = None
|
||||
|
||||
provider_id: str = Field(..., description="The id of the provider")
|
||||
credential_id: str | None = Field(default=None, description="The id of the credential")
|
||||
tool_name: str = Field(..., description="The name of the tool")
|
||||
tool_description: str = Field(..., description="The description of the tool")
|
||||
tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
|
||||
tool_parameters: Mapping[str, Parameter] = Field(..., description="Parameters, type llm")
|
||||
|
||||
def to_plugin_parameter(self) -> dict[str, Any]:
|
||||
return self.model_dump()
|
||||
117
dify/api/core/tools/entities/values.py
Normal file
117
dify/api/core/tools/entities/values.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolLabel, ToolLabelEnum
|
||||
|
||||
ICONS = {
|
||||
ToolLabelEnum.SEARCH: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M7.33398 1.3335C10.646 1.3335 13.334 4.0215 13.334 7.3335C13.334 10.6455 10.646 13.3335 7.33398 13.3335C4.02198 13.3335 1.33398 10.6455 1.33398 7.3335C1.33398 4.0215 4.02198 1.3335 7.33398 1.3335ZM7.33398 12.0002C9.91232 12.0002 12.0007 9.91183 12.0007 7.3335C12.0007 4.75516 9.91232 2.66683 7.33398 2.66683C4.75565 2.66683 2.66732 4.75516 2.66732 7.3335C2.66732 9.91183 4.75565 12.0002 7.33398 12.0002ZM12.9909 12.0476L14.8764 13.9332L13.9337 14.876L12.0481 12.9904L12.9909 12.0476Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.IMAGE: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M13.0514 9.71752L10.4718 7.13792C10.2115 6.87752 9.78932 6.87752 9.52898 7.13792L4.57721 12.0897C3.4097 11.1113 2.66732 9.64232 2.66732 7.99992C2.66732 5.0544 5.05513 2.66659 8.00065 2.66659C10.9462 2.66659 13.334 5.0544 13.334 7.99992C13.334 8.60085 13.2346 9.17852 13.0514 9.71752ZM5.72683 12.8257L10.0004 8.55212L12.4259 10.9777C11.4668 12.4001 9.84152 13.3331 8.00038 13.3331C7.18632 13.3331 6.41628 13.1511 5.72683 12.8257ZM8.00065 14.6666C11.6825 14.6666 14.6673 11.6818 14.6673 7.99992C14.6673 4.31802 11.6825 1.33325 8.00065 1.33325C4.31875 1.33325 1.33398 4.31802 1.33398 7.99992C1.33398 11.6818 4.31875 14.6666 8.00065 14.6666ZM7.33398 6.66658C7.33398 7.40299 6.73705 7.99992 6.00065 7.99992C5.26427 7.99992 4.66732 7.40299 4.66732 6.66658C4.66732 5.9302 5.26427 5.33325 6.00065 5.33325C6.73705 5.33325 7.33398 5.9302 7.33398 6.66658Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.VIDEOS: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.00065 13.3333H13.334V14.6666H8.00065C4.31875 14.6666 1.33398 11.6818 1.33398 7.99992C1.33398 4.31802 4.31875 1.33325 8.00065 1.33325C11.6825 1.33325 14.6673 4.31802 14.6673 7.99992C14.6673 9.50072 14.1714 10.8857 13.3345 11.9999H11.5284C12.6356 11.0227 13.334 9.59285 13.334 7.99992C13.334 5.0544 10.9462 2.66659 8.00065 2.66659C5.05513 2.66659 2.66732 5.0544 2.66732 7.99992C2.66732 10.9455 5.05513 13.3333 8.00065 13.3333ZM8.00065 6.66658C7.26425 6.66658 6.66732 6.06963 6.66732 5.33325C6.66732 4.59687 7.26425 3.99992 8.00065 3.99992C8.73705 3.99992 9.33398 4.59687 9.33398 5.33325C9.33398 6.06963 8.73705 6.66658 8.00065 6.66658ZM5.33398 9.33325C4.5976 9.33325 4.00065 8.73632 4.00065 7.99992C4.00065 7.26352 4.5976 6.66658 5.33398 6.66658C6.07036 6.66658 6.66732 7.26352 6.66732 7.99992C6.66732 8.73632 6.07036 9.33325 5.33398 9.33325ZM10.6673 9.33325C9.93092 9.33325 9.33398 8.73632 9.33398 7.99992C9.33398 7.26352 9.93092 6.66658 10.6673 6.66658C11.4037 6.66658 12.0007 7.26352 12.0007 7.99992C12.0007 8.73632 11.4037 9.33325 10.6673 9.33325ZM8.00065 11.9999C7.26425 11.9999 6.66732 11.403 6.66732 10.6666C6.66732 9.93018 7.26425 9.33325 8.00065 9.33325C8.73705 9.33325 9.33398 9.93018 9.33398 10.6666C9.33398 11.403 8.73705 11.9999 8.00065 11.9999Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.WEATHER: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M6.6553 3.37344C7.42088 2.1484 8.78162 1.3335 10.3327 1.3335C12.7259 1.3335 14.666 3.2736 14.666 5.66683C14.666 6.38704 14.4903 7.06623 14.1794 7.66383C14.8894 8.3325 15.3327 9.28123 15.3327 10.3335C15.3327 12.3586 13.6911 14.0002 11.666 14.0002H5.99935C3.05383 14.0002 0.666016 11.6124 0.666016 8.66683C0.666016 5.72131 3.05383 3.3335 5.99935 3.3335C6.22143 3.3335 6.44034 3.34707 6.6553 3.37344ZM8.03628 3.73629C9.37768 4.29108 10.4435 5.37735 10.9711 6.73256C11.1961 6.68943 11.4284 6.66683 11.666 6.66683C12.1561 6.66683 12.6237 6.76296 13.0511 6.93743C13.2317 6.55162 13.3327 6.12102 13.3327 5.66683C13.3327 4.00998 11.9895 2.66683 10.3327 2.66683C9.41115 2.66683 8.58662 3.08236 8.03628 3.73629ZM11.666 12.6668C12.9547 12.6668 13.9993 11.6222 13.9993 10.3335C13.9993 9.04483 12.9547 8.00016 11.666 8.00016C11.013 8.00016 10.4227 8.26836 9.99922 8.70063C9.99928 8.68936 9.99935 8.6781 9.99935 8.66683C9.99935 6.45769 8.20848 4.66683 5.99935 4.66683C3.79021 4.66683 1.99935 6.45769 1.99935 8.66683C1.99935 10.876 3.79021 12.6668 5.99935 12.6668H11.666Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.FINANCE: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.00262 14.6685C4.32071 14.6685 1.33594 11.6838 1.33594 8.00184C1.33594 4.31997 4.32071 1.33521 8.00262 1.33521C11.6845 1.33521 14.6693 4.31997 14.6693 8.00184C14.6693 11.6838 11.6845 14.6685 8.00262 14.6685ZM8.00262 13.3352C10.9482 13.3352 13.336 10.9474 13.336 8.00184C13.336 5.05635 10.9482 2.66854 8.00262 2.66854C5.05708 2.66854 2.66927 5.05635 2.66927 8.00184C2.66927 10.9474 5.05708 13.3352 8.00262 13.3352ZM5.66927 9.33517H9.33595C9.52002 9.33517 9.66928 9.18597 9.66928 9.00184C9.66928 8.81777 9.52002 8.66851 9.33595 8.66851H6.66928C5.7488 8.66851 5.0026 7.92237 5.0026 7.00184C5.0026 6.08139 5.7488 5.33521 6.66928 5.33521H7.33595V4.00187H8.66928V5.33521H10.336V6.66851H6.66928C6.48518 6.66851 6.33594 6.81777 6.33594 7.00184C6.33594 7.18597 6.48518 7.33517 6.66928 7.33517H9.33595C10.2564 7.33517 11.0026 8.08137 11.0026 9.00184C11.0026 9.92237 10.2564 10.6685 9.33595 10.6685H8.66928V12.0018H7.33595V10.6685H5.66927V9.33517Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.DESIGN: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M4.70152 9.41416L3.2873 10.8284L5.17292 12.714L12.7154 5.17154L10.8298 3.28592L9.41557 4.70013L10.3584 5.64295L9.41557 6.58575L8.47277 5.64295L7.52997 6.58575L8.47277 7.52856L7.52997 8.47136L6.58713 7.52856L5.64433 8.47136L6.58713 9.41416L5.64433 10.357L4.70152 9.41416ZM11.3012 1.87171L14.1296 4.70013C14.39 4.96049 14.39 5.38259 14.1296 5.64295L5.64433 14.1282C5.38397 14.3886 4.96187 14.3886 4.70152 14.1282L1.87309 11.2998C1.61274 11.0394 1.61274 10.6174 1.87309 10.357L10.3584 1.87171C10.6187 1.61136 11.0408 1.61136 11.3012 1.87171ZM9.41557 12.2423L10.3584 11.2995L11.8534 12.7945H12.7962V11.8517L11.3012 10.3567L12.244 9.41383L14.0011 11.171V13.9999H11.1732L9.41557 12.2423ZM3.75861 6.58533L1.87299 4.69971C1.61265 4.43937 1.61265 4.01725 1.87299 3.75691L3.75861 1.87129C4.01896 1.61094 4.44107 1.61094 4.70142 1.87129L6.58704 3.75691L5.64423 4.69971L4.23002 3.2855L3.28721 4.22831L4.70142 5.64253L3.75861 6.58533Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.TRAVEL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M9.44839 2C9.80198 2 10.1411 2.14047 10.3912 2.39053L13.6101 5.60947C13.8602 5.85953 14.0007 6.19866 14.0007 6.55229V11.3333H15.334V12.6667L9.91652 12.6672C9.62032 13.8171 8.57638 14.6667 7.33398 14.6667C6.0916 14.6667 5.04766 13.8171 4.75146 12.6672L2.00065 12.6667C1.63246 12.6667 1.33398 12.3682 1.33398 12V3.33333C1.33398 2.59695 1.93094 2 2.66732 2H9.44839ZM7.33398 10.6667C6.5976 10.6667 6.00065 11.2636 6.00065 12C6.00065 12.7364 6.5976 13.3333 7.33398 13.3333C8.07038 13.3333 8.66732 12.7364 8.66732 12C8.66732 11.2636 8.07038 10.6667 7.33398 10.6667ZM9.44839 3.33333H2.66732V11.3333L4.75128 11.3335C5.04726 10.1833 6.09136 9.33333 7.33398 9.33333C8.57658 9.33333 9.62072 10.1833 9.91665 11.3335L12.6673 11.3333V6.55229L9.44839 3.33333ZM9.33398 4.66667V8.66667H4.00065V4.66667H9.33398ZM8.00065 6H5.33398V7.33333H8.00065V6Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.SOCIAL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M13.334 7.99992C13.334 5.0544 10.9462 2.66659 8.00065 2.66659C5.05513 2.66659 2.66732 5.0544 2.66732 7.99992C2.66732 10.9455 5.05513 13.3333 8.00065 13.3333C9.09518 13.3333 10.1127 13.0035 10.9594 12.438L11.699 13.5475C10.6408 14.2545 9.36885 14.6666 8.00065 14.6666C4.31875 14.6666 1.33398 11.6818 1.33398 7.99992C1.33398 4.31802 4.31875 1.33325 8.00065 1.33325C11.6825 1.33325 14.6673 4.31802 14.6673 7.99992V8.99992C14.6673 10.2886 13.6227 11.3333 12.334 11.3333C11.5312 11.3333 10.8231 10.9278 10.4032 10.3105C9.79678 10.9409 8.94452 11.3333 8.00065 11.3333C6.1597 11.3333 4.66732 9.84085 4.66732 7.99992C4.66732 6.15897 6.1597 4.66658 8.00065 4.66658C8.75118 4.66658 9.44378 4.91464 10.001 5.33325H11.334V8.99992C11.334 9.55219 11.7817 9.99992 12.334 9.99992C12.8863 9.99992 13.334 9.55219 13.334 8.99992V7.99992ZM8.00065 5.99992C6.89605 5.99992 6.00065 6.89532 6.00065 7.99992C6.00065 9.10452 6.89605 9.99992 8.00065 9.99992C9.10525 9.99992 10.0007 9.10452 10.0007 7.99992C10.0007 6.89532 9.10525 5.99992 8.00065 5.99992Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.NEWS: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M10.6673 13.3335V2.66683H2.66732V12.6668C2.66732 13.035 2.9658 13.3335 3.33398 13.3335H10.6673ZM12.6673 14.6668H3.33398C2.22942 14.6668 1.33398 13.7714 1.33398 12.6668V2.00016C1.33398 1.63198 1.63246 1.3335 2.00065 1.3335H11.334C11.7022 1.3335 12.0007 1.63198 12.0007 2.00016V6.66683H14.6673V12.6668C14.6673 13.7714 13.7719 14.6668 12.6673 14.6668ZM12.0007 8.00016V12.6668C12.0007 13.035 12.2991 13.3335 12.6673 13.3335C13.0355 13.3335 13.334 13.035 13.334 12.6668V8.00016H12.0007ZM4.00065 4.00016H8.00065V8.00016H4.00065V4.00016ZM5.33398 5.3335V6.66683H6.66732V5.3335H5.33398ZM4.00065 8.66683H9.33398V10.0002H4.00065V8.66683ZM4.00065 10.6668H9.33398V12.0002H4.00065V10.6668Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.MEDICAL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.79747 1.51186L10.9641 5.26464C11.1482 5.5835 11.0389 5.99122 10.7201 6.17532L9.85373 6.67474L10.5207 7.83001L9.366 8.49668L8.699 7.34141L7.83333 7.84201C7.51447 8.02608 7.10673 7.91681 6.92267 7.59794L5.69747 5.47632C4.32922 5.89145 3.33333 7.16268 3.33333 8.66654C3.33333 9.08348 3.40987 9.48248 3.54965 9.85034C4.06613 9.52254 4.67762 9.33321 5.33333 9.33321C6.45605 9.33321 7.44913 9.88828 8.05313 10.7389L13.1787 7.78014L13.8454 8.93488L8.5932 11.9672C8.64133 12.1927 8.66667 12.4267 8.66667 12.6665C8.66667 12.895 8.64367 13.1181 8.59993 13.3337L14 13.3332V14.6665L2.66703 14.6673C2.2482 14.1101 2 13.4173 2 12.6665C2 11.9951 2.19855 11.3699 2.54014 10.8467C2.19517 10.1964 2 9.45428 2 8.66654C2 6.66968 3.25421 4.96575 5.01785 4.29953L4.75598 3.84519C4.38779 3.20747 4.60629 2.39202 5.24402 2.02382L6.97607 1.02382C7.6138 0.655637 8.42927 0.874138 8.79747 1.51186ZM5.33333 10.6665C4.22877 10.6665 3.33333 11.562 3.33333 12.6665C3.33333 12.9003 3.37343 13.1247 3.44711 13.3331H7.21953C7.29327 13.1247 7.33333 12.9003 7.33333 12.6665C7.33333 11.562 6.4379 10.6665 5.33333 10.6665ZM7.64273 2.17852L5.91068 3.17852L7.744 6.35395L9.47607 5.35395L7.64273 2.17852Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.PRODUCTIVITY: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M6.64807 11.9999H9.35062C9.43862 11.1989 9.84742 10.5376 10.5111 9.81499C10.5858 9.73365 11.0652 9.23752 11.1221 9.16665C11.6872 8.46199 11.9993 7.58992 11.9993 6.66659C11.9993 4.45745 10.2085 2.66659 7.99935 2.66659C5.79021 2.66659 3.99935 4.45745 3.99935 6.66659C3.99935 7.58945 4.31118 8.46105 4.87576 9.16552C4.93271 9.23659 5.41322 9.73405 5.48704 9.81445C6.15112 10.5375 6.56004 11.1989 6.64807 11.9999ZM9.33268 13.3333H6.66602V13.9999H9.33268V13.3333ZM3.83532 9.99939C3.10365 9.08639 2.66602 7.92759 2.66602 6.66659C2.66602 3.72107 5.05383 1.33325 7.99935 1.33325C10.9449 1.33325 13.3327 3.72107 13.3327 6.66659C13.3327 7.92825 12.8945 9.08759 12.1622 10.0009C11.7487 10.5165 10.666 11.3333 10.666 12.3333V13.9999C10.666 14.7363 10.0691 15.3333 9.33268 15.3333H6.66602C5.92964 15.3333 5.33268 14.7363 5.33268 13.9999V12.3333C5.33268 11.3333 4.24907 10.5157 3.83532 9.99939ZM8.66602 6.66979H10.3327L7.33268 10.6698V8.00312H5.66602L8.66602 3.99992V6.66979Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.EDUCATION: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M14 2.66683H4.66667C3.93029 2.66683 3.33333 3.26378 3.33333 4.00016C3.33333 4.73654 3.93029 5.3335 4.66667 5.3335H14V14.0002C14 14.3684 13.7015 14.6668 13.3333 14.6668H4.66667C3.19391 14.6668 2 13.4729 2 12.0002V4.00016C2 2.5274 3.19391 1.3335 4.66667 1.3335H13.3333C13.7015 1.3335 14 1.63198 14 2.00016V2.66683ZM3.33333 12.0002C3.33333 12.7366 3.93029 13.3335 4.66667 13.3335H12.6667V6.66683H4.66667C4.18095 6.66683 3.72557 6.53697 3.33333 6.31008V12.0002ZM13.3333 4.66683H4.66667C4.29848 4.66683 4 4.36835 4 4.00016C4 3.63198 4.29848 3.3335 4.66667 3.3335H13.3333V4.66683Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.BUSINESS: """<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 14 14" fill="none">
|
||||
<path d="M3.66732 3.33341V1.33341C3.66732 0.965228 3.9658 0.666748 4.33398 0.666748H9.66732C10.0355 0.666748 10.334 0.965228 10.334 1.33341V3.33341H13.0007C13.3689 3.33341 13.6673 3.63189 13.6673 4.00008V13.3334C13.6673 13.7016 13.3689 14.0001 13.0007 14.0001H1.00065C0.632464 14.0001 0.333984 13.7016 0.333984 13.3334V4.00008C0.333984 3.63189 0.632464 3.33341 1.00065 3.33341H3.66732ZM12.334 8.66675H1.66732V12.6667H12.334V8.66675ZM12.334 4.66675H1.66732V7.33341H3.66732V6.00008H5.00065V7.33341H9.00065V6.00008H10.334V7.33341H12.334V4.66675ZM5.00065 2.00008V3.33341H9.00065V2.00008H5.00065Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.ENTERTAINMENT: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M11.3327 2.66675C13.5418 2.66675 15.3327 4.45761 15.3327 6.66675V9.33342C15.3327 11.5425 13.5418 13.3334 11.3327 13.3334H4.66602C2.45688 13.3334 0.666016 11.5425 0.666016 9.33342V6.66675C0.666016 4.45761 2.45688 2.66675 4.66602 2.66675H11.3327ZM11.3327 4.00008H4.66602C3.23788 4.00008 2.07196 5.12273 2.00262 6.53365L1.99935 6.66675V9.33342C1.99935 10.7615 3.122 11.9275 4.53292 11.9968L4.66602 12.0001H11.3327C12.7608 12.0001 13.9267 10.8774 13.9961 9.46648L13.9993 9.33342V6.66675C13.9993 5.23861 12.8767 4.07269 11.4657 4.00335L11.3327 4.00008ZM6.66602 6.00008V7.33342H7.99935V8.66675H6.66535L6.66602 10.0001H5.33268L5.33202 8.66675H3.99935V7.33342H5.33268V6.00008H6.66602ZM11.9993 8.66675V10.0001H10.666V8.66675H11.9993ZM10.666 6.00008V7.33342H9.33268V6.00008H10.666Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.UTILITIES: """<svg xmlns="http://www.w3.org/2000/svg" width="13" height="15" viewBox="0 0 13 15" fill="none">
|
||||
<path d="M12.3346 0.333252C12.7028 0.333252 13.0013 0.631732 13.0013 0.999919V4.33325C13.0013 4.70144 12.7028 4.99992 12.3346 4.99992H9.0013V13.6666C9.0013 14.0348 8.70284 14.3333 8.33463 14.3333H5.66797C5.29978 14.3333 5.0013 14.0348 5.0013 13.6666V4.99992H1.33464C0.966449 4.99992 0.667969 4.70144 0.667969 4.33325V2.74527C0.667969 2.49276 0.810635 2.26192 1.0365 2.14899L4.66797 0.333252H12.3346ZM9.0013 1.66659H4.98273L2.0013 3.1573V3.66659H6.33464V12.9999H7.66797V3.66659H9.0013V1.66659ZM11.668 1.66659H10.3346V3.66659H11.668V1.66659Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.OTHER: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.00052 0.666748L4.00065 7.33342H12.0007L8.00052 0.666748ZM8.00052 3.25828L9.64572 6.00008H6.35553L8.00052 3.25828ZM4.50065 13.3334C3.48813 13.3334 2.66732 12.5126 2.66732 11.5001C2.66732 10.4875 3.48813 9.66675 4.50065 9.66675C5.51317 9.66675 6.33398 10.4875 6.33398 11.5001C6.33398 12.5126 5.51317 13.3334 4.50065 13.3334ZM4.50065 14.6667C6.24955 14.6667 7.66732 13.249 7.66732 11.5001C7.66732 9.75115 6.24955 8.33342 4.50065 8.33342C2.75175 8.33342 1.33398 9.75115 1.33398 11.5001C1.33398 13.249 2.75175 14.6667 4.50065 14.6667ZM10.0007 10.3334V13.0001H12.6673V10.3334H10.0007ZM8.66732 14.3334V9.00008H14.0007V14.3334H8.66732Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.RAG: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.00065 1.3335H9.33398V2.66683H8.00065V1.3335ZM5.33398 1.3335H6.66732V2.66683H5.33398V1.3335ZM3.99935 2.66683C3.99935 2.29864 4.29783 2.00016 4.66602 2.00016H12.3327C12.7009 2.00016 13.0007 2.29864 13.0007 2.66683V13.3335C13.0007 13.7017 12.7009 14.0002 12.3327 14.0002H4.66602C4.29783 14.0002 3.99935 13.7017 3.99935 13.3335V2.66683ZM4.66602 12.6668C4.29783 12.6668 3.99935 12.3683 3.99935 12.0002V10.6668H5.33398V12.0002C5.33398 12.3683 5.0355 12.6668 4.66602 12.6668ZM5.33398 8.66683H6.66732V10.0002H5.33398V8.66683ZM5.33398 6.66683H6.66732V8.00016H5.33398V6.66683ZM3.99935 4.66683H6.66602V6.00016H3.99935V4.66683ZM6.66602 1.3335H12.3327V2.66683H6.66602V1.3335Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
}
|
||||
|
||||
default_tool_label_dict = {
|
||||
ToolLabelEnum.SEARCH: ToolLabel(
|
||||
name="search", label=I18nObject(en_US="Search", zh_Hans="搜索"), icon=ICONS[ToolLabelEnum.SEARCH]
|
||||
),
|
||||
ToolLabelEnum.IMAGE: ToolLabel(
|
||||
name="image", label=I18nObject(en_US="Image", zh_Hans="图片"), icon=ICONS[ToolLabelEnum.IMAGE]
|
||||
),
|
||||
ToolLabelEnum.VIDEOS: ToolLabel(
|
||||
name="videos", label=I18nObject(en_US="Videos", zh_Hans="视频"), icon=ICONS[ToolLabelEnum.VIDEOS]
|
||||
),
|
||||
ToolLabelEnum.WEATHER: ToolLabel(
|
||||
name="weather", label=I18nObject(en_US="Weather", zh_Hans="天气"), icon=ICONS[ToolLabelEnum.WEATHER]
|
||||
),
|
||||
ToolLabelEnum.FINANCE: ToolLabel(
|
||||
name="finance", label=I18nObject(en_US="Finance", zh_Hans="金融"), icon=ICONS[ToolLabelEnum.FINANCE]
|
||||
),
|
||||
ToolLabelEnum.DESIGN: ToolLabel(
|
||||
name="design", label=I18nObject(en_US="Design", zh_Hans="设计"), icon=ICONS[ToolLabelEnum.DESIGN]
|
||||
),
|
||||
ToolLabelEnum.TRAVEL: ToolLabel(
|
||||
name="travel", label=I18nObject(en_US="Travel", zh_Hans="旅行"), icon=ICONS[ToolLabelEnum.TRAVEL]
|
||||
),
|
||||
ToolLabelEnum.SOCIAL: ToolLabel(
|
||||
name="social", label=I18nObject(en_US="Social", zh_Hans="社交"), icon=ICONS[ToolLabelEnum.SOCIAL]
|
||||
),
|
||||
ToolLabelEnum.NEWS: ToolLabel(
|
||||
name="news", label=I18nObject(en_US="News", zh_Hans="新闻"), icon=ICONS[ToolLabelEnum.NEWS]
|
||||
),
|
||||
ToolLabelEnum.MEDICAL: ToolLabel(
|
||||
name="medical", label=I18nObject(en_US="Medical", zh_Hans="医疗"), icon=ICONS[ToolLabelEnum.MEDICAL]
|
||||
),
|
||||
ToolLabelEnum.PRODUCTIVITY: ToolLabel(
|
||||
name="productivity",
|
||||
label=I18nObject(en_US="Productivity", zh_Hans="生产力"),
|
||||
icon=ICONS[ToolLabelEnum.PRODUCTIVITY],
|
||||
),
|
||||
ToolLabelEnum.EDUCATION: ToolLabel(
|
||||
name="education", label=I18nObject(en_US="Education", zh_Hans="教育"), icon=ICONS[ToolLabelEnum.EDUCATION]
|
||||
),
|
||||
ToolLabelEnum.BUSINESS: ToolLabel(
|
||||
name="business", label=I18nObject(en_US="Business", zh_Hans="商业"), icon=ICONS[ToolLabelEnum.BUSINESS]
|
||||
),
|
||||
ToolLabelEnum.ENTERTAINMENT: ToolLabel(
|
||||
name="entertainment",
|
||||
label=I18nObject(en_US="Entertainment", zh_Hans="娱乐"),
|
||||
icon=ICONS[ToolLabelEnum.ENTERTAINMENT],
|
||||
),
|
||||
ToolLabelEnum.UTILITIES: ToolLabel(
|
||||
name="utilities", label=I18nObject(en_US="Utilities", zh_Hans="工具"), icon=ICONS[ToolLabelEnum.UTILITIES]
|
||||
),
|
||||
ToolLabelEnum.OTHER: ToolLabel(
|
||||
name="other", label=I18nObject(en_US="Other", zh_Hans="其他"), icon=ICONS[ToolLabelEnum.OTHER]
|
||||
),
|
||||
ToolLabelEnum.RAG: ToolLabel(
|
||||
name="rag", label=I18nObject(en_US="RAG", zh_Hans="RAG"), icon=ICONS[ToolLabelEnum.RAG]
|
||||
),
|
||||
}
|
||||
|
||||
default_tool_labels = list(default_tool_label_dict.values())
|
||||
default_tool_label_name_list = [label.name for label in default_tool_labels]
|
||||
41
dify/api/core/tools/errors.py
Normal file
41
dify/api/core/tools/errors.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
|
||||
|
||||
class ToolProviderNotFoundError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolNotFoundError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolParameterValidationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolProviderCredentialValidationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolNotSupportedError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolInvokeError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolApiSchemaError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolCredentialPolicyViolationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolEngineInvokeError(Exception):
|
||||
meta: ToolInvokeMeta
|
||||
|
||||
def __init__(self, meta, **kwargs):
|
||||
self.meta = meta
|
||||
super().__init__(**kwargs)
|
||||
156
dify/api/core/tools/mcp_tool/provider.py
Normal file
156
dify/api/core/tools/mcp_tool/provider.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from typing import Any, Self
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.mcp.types import Tool as RemoteMCPTool
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolEntity,
|
||||
ToolIdentity,
|
||||
ToolProviderEntityWithPlugin,
|
||||
ToolProviderIdentity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.mcp_tool.tool import MCPTool
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
class MCPToolProviderController(ToolProviderController):
|
||||
def __init__(
|
||||
self,
|
||||
entity: ToolProviderEntityWithPlugin,
|
||||
provider_id: str,
|
||||
tenant_id: str,
|
||||
server_url: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
):
|
||||
super().__init__(entity)
|
||||
self.entity: ToolProviderEntityWithPlugin = entity
|
||||
self.tenant_id = tenant_id
|
||||
self.provider_id = provider_id
|
||||
self.server_url = server_url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return ToolProviderType.MCP
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_provider: MCPToolProvider) -> Self:
|
||||
"""
|
||||
from db provider
|
||||
"""
|
||||
# Convert to entity first
|
||||
provider_entity = db_provider.to_entity()
|
||||
return cls.from_entity(provider_entity)
|
||||
|
||||
@classmethod
|
||||
def from_entity(cls, entity: MCPProviderEntity) -> Self:
|
||||
"""
|
||||
create a MCPToolProviderController from a MCPProviderEntity
|
||||
"""
|
||||
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in entity.tools]
|
||||
|
||||
tools = [
|
||||
ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
author="Anonymous", # Tool level author is not stored
|
||||
name=remote_mcp_tool.name,
|
||||
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
|
||||
provider=entity.provider_id,
|
||||
icon=entity.icon if isinstance(entity.icon, str) else "",
|
||||
),
|
||||
parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
|
||||
description=ToolDescription(
|
||||
human=I18nObject(
|
||||
en_US=remote_mcp_tool.description or "", zh_Hans=remote_mcp_tool.description or ""
|
||||
),
|
||||
llm=remote_mcp_tool.description or "",
|
||||
),
|
||||
output_schema=remote_mcp_tool.outputSchema or {},
|
||||
has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0,
|
||||
)
|
||||
for remote_mcp_tool in remote_mcp_tools
|
||||
]
|
||||
if not entity.icon:
|
||||
raise ValueError("Database provider icon is required")
|
||||
return cls(
|
||||
entity=ToolProviderEntityWithPlugin(
|
||||
identity=ToolProviderIdentity(
|
||||
author="Anonymous", # Provider level author is not stored in entity
|
||||
name=entity.name,
|
||||
label=I18nObject(en_US=entity.name, zh_Hans=entity.name),
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
icon=entity.icon if isinstance(entity.icon, str) else "",
|
||||
),
|
||||
plugin_id=None,
|
||||
credentials_schema=[],
|
||||
tools=tools,
|
||||
),
|
||||
provider_id=entity.provider_id,
|
||||
tenant_id=entity.tenant_id,
|
||||
server_url=entity.server_url,
|
||||
headers=entity.headers,
|
||||
timeout=entity.timeout,
|
||||
sse_read_timeout=entity.sse_read_timeout,
|
||||
)
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_tool(self, tool_name: str) -> MCPTool:
|
||||
"""
|
||||
return tool with given name
|
||||
"""
|
||||
tool_entity = next(
|
||||
(tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name), None
|
||||
)
|
||||
|
||||
if not tool_entity:
|
||||
raise ValueError(f"Tool with name {tool_name} not found")
|
||||
|
||||
return MCPTool(
|
||||
entity=tool_entity,
|
||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
server_url=self.server_url,
|
||||
provider_id=self.provider_id,
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
|
||||
def get_tools(self) -> list[MCPTool]:
|
||||
"""
|
||||
get all tools
|
||||
"""
|
||||
return [
|
||||
MCPTool(
|
||||
entity=tool_entity,
|
||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
server_url=self.server_url,
|
||||
provider_id=self.provider_id,
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
for tool_entity in self.entity.tools
|
||||
]
|
||||
174
dify/api/core/tools/mcp_tool/tool.py
Normal file
174
dify/api/core/tools/mcp_tool/tool.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPConnectionError
|
||||
from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
def __init__(
|
||||
self,
|
||||
entity: ToolEntity,
|
||||
runtime: ToolRuntime,
|
||||
tenant_id: str,
|
||||
icon: str,
|
||||
server_url: str,
|
||||
provider_id: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
):
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.server_url = server_url
|
||||
self.provider_id = provider_id
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.MCP
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
result = self.invoke_remote_mcp_tool(tool_parameters)
|
||||
# handle dify tool output
|
||||
for content in result.content:
|
||||
if isinstance(content, TextContent):
|
||||
yield from self._process_text_content(content)
|
||||
elif isinstance(content, ImageContent):
|
||||
yield self._process_image_content(content)
|
||||
elif isinstance(content, AudioContent):
|
||||
yield self._process_audio_content(content)
|
||||
else:
|
||||
logger.warning("Unsupported content type=%s", type(content))
|
||||
|
||||
# handle MCP structured output
|
||||
if self.entity.output_schema and result.structuredContent:
|
||||
for k, v in result.structuredContent.items():
|
||||
yield self.create_variable_message(k, v)
|
||||
|
||||
def _process_text_content(self, content: TextContent) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""Process text content and yield appropriate messages."""
|
||||
# Check if content looks like JSON before attempting to parse
|
||||
text = content.text.strip()
|
||||
if text and text[0] in ("{", "[") and text[-1] in ("}", "]"):
|
||||
try:
|
||||
content_json = json.loads(text)
|
||||
yield from self._process_json_content(content_json)
|
||||
return
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# If not JSON or parsing failed, treat as plain text
|
||||
yield self.create_text_message(content.text)
|
||||
|
||||
def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""Process JSON content based on its type."""
|
||||
if isinstance(content_json, dict):
|
||||
yield self.create_json_message(content_json)
|
||||
elif isinstance(content_json, list):
|
||||
yield from self._process_json_list(content_json)
|
||||
else:
|
||||
# For primitive types (str, int, bool, etc.), convert to string
|
||||
yield self.create_text_message(str(content_json))
|
||||
|
||||
def _process_json_list(self, json_list: list) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""Process a list of JSON items."""
|
||||
if any(not isinstance(item, dict) for item in json_list):
|
||||
# If the list contains any non-dict item, treat the entire list as a text message.
|
||||
yield self.create_text_message(str(json_list))
|
||||
return
|
||||
|
||||
# Otherwise, process each dictionary as a separate JSON message.
|
||||
for item in json_list:
|
||||
yield self.create_json_message(item)
|
||||
|
||||
def _process_image_content(self, content: ImageContent) -> ToolInvokeMessage:
|
||||
"""Process image content and return a blob message."""
|
||||
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
|
||||
|
||||
def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage:
|
||||
"""Process audio content and return a blob message."""
|
||||
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
|
||||
return MCPTool(
|
||||
entity=self.entity,
|
||||
runtime=runtime,
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.icon,
|
||||
server_url=self.server_url,
|
||||
provider_id=self.provider_id,
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
|
||||
def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
in mcp tool invoke, if the parameter is empty, it will be set to None
|
||||
"""
|
||||
return {
|
||||
key: value
|
||||
for key, value in parameter.items()
|
||||
if value is not None and not (isinstance(value, str) and value.strip() == "")
|
||||
}
|
||||
|
||||
def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any]) -> CallToolResult:
|
||||
headers = self.headers.copy() if self.headers else {}
|
||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
# Step 1: Load provider entity and credentials in a short-lived session
|
||||
# This minimizes database connection hold time
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
|
||||
|
||||
# Decrypt and prepare all credentials before closing session
|
||||
server_url = provider_entity.decrypt_server_url()
|
||||
headers = provider_entity.decrypt_headers()
|
||||
|
||||
# Try to get existing token and add to headers
|
||||
if not headers:
|
||||
tokens = provider_entity.retrieve_tokens()
|
||||
if tokens and tokens.access_token:
|
||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||
|
||||
# Step 2: Session is now closed, perform network operations without holding database connection
|
||||
# MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed
|
||||
try:
|
||||
with MCPClientWithAuthRetry(
|
||||
server_url=server_url,
|
||||
headers=headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
) as mcp_client:
|
||||
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPConnectionError as e:
|
||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||
79
dify/api/core/tools/plugin_tool/provider.py
Normal file
79
dify/api/core/tools/plugin_tool/provider.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from typing import Any
|
||||
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin, ToolProviderType
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
|
||||
|
||||
class PluginToolProviderController(BuiltinToolProviderController):
|
||||
entity: ToolProviderEntityWithPlugin
|
||||
tenant_id: str
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self, entity: ToolProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||
):
|
||||
self.entity = entity
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return ToolProviderType.PLUGIN
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
manager = PluginToolManager()
|
||||
if not manager.validate_provider_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
provider=self.entity.identity.name,
|
||||
credentials=credentials,
|
||||
):
|
||||
raise ToolProviderCredentialValidationError("Invalid credentials")
|
||||
|
||||
def get_tool(self, tool_name: str) -> PluginTool: # type: ignore
|
||||
"""
|
||||
return tool with given name
|
||||
"""
|
||||
tool_entity = next(
|
||||
(tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name), None
|
||||
)
|
||||
|
||||
if not tool_entity:
|
||||
raise ValueError(f"Tool with name {tool_name} not found")
|
||||
|
||||
return PluginTool(
|
||||
entity=tool_entity,
|
||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
|
||||
def get_tools(self) -> list[PluginTool]: # type: ignore
|
||||
"""
|
||||
get all tools
|
||||
"""
|
||||
return [
|
||||
PluginTool(
|
||||
entity=tool_entity,
|
||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
for tool_entity in self.entity.tools
|
||||
]
|
||||
85
dify/api/core/tools/plugin_tool/tool.py
Normal file
85
dify/api/core/tools/plugin_tool/tool.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
|
||||
|
||||
class PluginTool(Tool):
|
||||
def __init__(
|
||||
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str
|
||||
):
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
self.runtime_parameters: list[ToolParameter] | None = None
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.PLUGIN
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
manager = PluginToolManager()
|
||||
|
||||
tool_parameters = convert_parameters_to_plugin_format(tool_parameters)
|
||||
|
||||
yield from manager.invoke(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
tool_provider=self.entity.identity.provider,
|
||||
tool_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
credential_type=self.runtime.credential_type,
|
||||
tool_parameters=tool_parameters,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool":
|
||||
return PluginTool(
|
||||
entity=self.entity,
|
||||
runtime=runtime,
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> list[ToolParameter]:
|
||||
"""
|
||||
get the runtime parameters
|
||||
"""
|
||||
if not self.entity.has_runtime_parameters:
|
||||
return self.entity.parameters
|
||||
|
||||
if self.runtime_parameters is not None:
|
||||
return self.runtime_parameters
|
||||
|
||||
manager = PluginToolManager()
|
||||
self.runtime_parameters = manager.get_runtime_parameters(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="",
|
||||
provider=self.entity.identity.provider,
|
||||
tool=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
return self.runtime_parameters
|
||||
42
dify/api/core/tools/signature.py
Normal file
42
dify/api/core/tools/signature.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import time
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
def sign_tool_file(tool_file_id: str, extension: str) -> str:
|
||||
"""
|
||||
sign file to get a temporary url for plugin access
|
||||
"""
|
||||
# Use internal URL for plugin/tool file access in Docker environments
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||
|
||||
|
||||
def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
"""
|
||||
verify signature
|
||||
"""
|
||||
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||
|
||||
# verify signature
|
||||
if sign != recalculated_encoded_sign:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
365
dify/api/core/tools/tool_engine.py
Normal file
365
dify/api/core/tools/tool_engine.py
Normal file
@@ -0,0 +1,365 @@
|
||||
import contextlib
|
||||
import json
|
||||
from collections.abc import Generator, Iterable
|
||||
from copy import deepcopy
|
||||
from datetime import UTC, datetime
|
||||
from mimetypes import guess_type
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file import FileType
|
||||
from core.file.models import FileTransferMethod
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolInvokeMessage,
|
||||
ToolInvokeMessageBinary,
|
||||
ToolInvokeMeta,
|
||||
ToolParameter,
|
||||
)
|
||||
from core.tools.errors import (
|
||||
ToolEngineInvokeError,
|
||||
ToolInvokeError,
|
||||
ToolNotFoundError,
|
||||
ToolNotSupportedError,
|
||||
ToolParameterValidationError,
|
||||
ToolProviderCredentialValidationError,
|
||||
ToolProviderNotFoundError,
|
||||
)
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import Message, MessageFile
|
||||
|
||||
|
||||
class ToolEngine:
|
||||
"""
|
||||
Tool runtime engine take care of the tool executions.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def agent_invoke(
|
||||
tool: Tool,
|
||||
tool_parameters: Union[str, dict],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
message: Message,
|
||||
invoke_from: InvokeFrom,
|
||||
agent_tool_callback: DifyAgentCallbackHandler,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> tuple[str, list[str], ToolInvokeMeta]:
|
||||
"""
|
||||
Agent invokes the tool with the given arguments.
|
||||
"""
|
||||
# check if arguments is a string
|
||||
if isinstance(tool_parameters, str):
|
||||
# check if this tool has only one parameter
|
||||
parameters = [
|
||||
parameter
|
||||
for parameter in tool.get_runtime_parameters()
|
||||
if parameter.form == ToolParameter.ToolParameterForm.LLM
|
||||
]
|
||||
if parameters and len(parameters) == 1:
|
||||
tool_parameters = {parameters[0].name: tool_parameters}
|
||||
else:
|
||||
with contextlib.suppress(Exception):
|
||||
tool_parameters = json.loads(tool_parameters)
|
||||
if not isinstance(tool_parameters, dict):
|
||||
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
|
||||
|
||||
try:
|
||||
# hit the callback handler
|
||||
agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
|
||||
|
||||
messages = ToolEngine._invoke(tool, tool_parameters, user_id, conversation_id, app_id, message_id)
|
||||
invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
|
||||
|
||||
def message_callback(
|
||||
invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]
|
||||
):
|
||||
for message in messages:
|
||||
if isinstance(message, ToolInvokeMeta):
|
||||
invocation_meta_dict["meta"] = message
|
||||
else:
|
||||
yield message
|
||||
|
||||
messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=message_callback(invocation_meta_dict, messages),
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=message.conversation_id,
|
||||
)
|
||||
|
||||
message_list = list(messages)
|
||||
|
||||
# extract binary data from tool invoke message
|
||||
binary_files = ToolEngine._extract_tool_response_binary_and_text(message_list)
|
||||
# create message file
|
||||
message_files = ToolEngine._create_message_files(
|
||||
tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id
|
||||
)
|
||||
|
||||
plain_text = ToolEngine._convert_tool_response_to_str(message_list)
|
||||
|
||||
meta = invocation_meta_dict["meta"]
|
||||
|
||||
# hit the callback handler
|
||||
agent_tool_callback.on_tool_end(
|
||||
tool_name=tool.entity.identity.name,
|
||||
tool_inputs=tool_parameters,
|
||||
tool_outputs=plain_text,
|
||||
message_id=message.id,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
# transform tool invoke message to get LLM friendly message
|
||||
return plain_text, message_files, meta
|
||||
except ToolProviderCredentialValidationError as e:
|
||||
error_response = "Please check your tool provider credentials"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e:
|
||||
error_response = f"there is not a tool named {tool.entity.identity.name}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except ToolParameterValidationError as e:
|
||||
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except ToolInvokeError as e:
|
||||
error_response = f"tool invoke error: {e}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except ToolEngineInvokeError as e:
|
||||
meta = e.meta
|
||||
error_response = f"tool invoke error: {meta.error}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
return error_response, [], meta
|
||||
except Exception as e:
|
||||
error_response = f"unknown error: {e}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
|
||||
return error_response, [], ToolInvokeMeta.error_instance(error_response)
|
||||
|
||||
@staticmethod
|
||||
def generic_invoke(
|
||||
tool: Tool,
|
||||
tool_parameters: dict[str, Any],
|
||||
user_id: str,
|
||||
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
||||
workflow_call_depth: int,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
Workflow invokes the tool with the given arguments.
|
||||
"""
|
||||
try:
|
||||
# hit the callback handler
|
||||
workflow_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
|
||||
|
||||
if isinstance(tool, WorkflowTool):
|
||||
tool.workflow_call_depth = workflow_call_depth + 1
|
||||
|
||||
if tool.runtime and tool.runtime.runtime_parameters:
|
||||
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
|
||||
|
||||
response = tool.invoke(
|
||||
user_id=user_id,
|
||||
tool_parameters=tool_parameters,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
# hit the callback handler
|
||||
response = workflow_tool_callback.on_tool_execution(
|
||||
tool_name=tool.entity.identity.name,
|
||||
tool_inputs=tool_parameters,
|
||||
tool_outputs=response,
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
workflow_tool_callback.on_tool_error(e)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def _invoke(
|
||||
tool: Tool,
|
||||
tool_parameters: dict,
|
||||
user_id: str,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]:
|
||||
"""
|
||||
Invoke the tool with the given arguments.
|
||||
"""
|
||||
started_at = datetime.now(UTC)
|
||||
meta = ToolInvokeMeta(
|
||||
time_cost=0.0,
|
||||
error=None,
|
||||
tool_config={
|
||||
"tool_name": tool.entity.identity.name,
|
||||
"tool_provider": tool.entity.identity.provider,
|
||||
"tool_provider_type": tool.tool_provider_type().value,
|
||||
"tool_parameters": deepcopy(tool.runtime.runtime_parameters),
|
||||
"tool_icon": tool.entity.identity.icon,
|
||||
},
|
||||
)
|
||||
try:
|
||||
yield from tool.invoke(user_id, tool_parameters, conversation_id, app_id, message_id)
|
||||
except Exception as e:
|
||||
meta.error = str(e)
|
||||
raise ToolEngineInvokeError(meta)
|
||||
finally:
|
||||
ended_at = datetime.now(UTC)
|
||||
meta.time_cost = (ended_at - started_at).total_seconds()
|
||||
yield meta
|
||||
|
||||
@staticmethod
|
||||
def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
|
||||
"""
|
||||
Handle tool response
|
||||
"""
|
||||
parts: list[str] = []
|
||||
json_parts: list[str] = []
|
||||
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
parts.append(cast(ToolInvokeMessage.TextMessage, response.message).text)
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
parts.append(
|
||||
f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}."
|
||||
+ " please tell user to check it."
|
||||
)
|
||||
elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
||||
parts.append(
|
||||
"image has been created and sent to user already, "
|
||||
+ "you do not need to create it, just tell the user to check it now."
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.JSON:
|
||||
json_message = cast(ToolInvokeMessage.JsonMessage, response.message)
|
||||
if json_message.suppress_output:
|
||||
continue
|
||||
json_parts.append(
|
||||
json.dumps(
|
||||
safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object),
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
else:
|
||||
parts.append(str(response.message))
|
||||
|
||||
# Add JSON parts, avoiding duplicates from text parts.
|
||||
if json_parts:
|
||||
existing_parts = set(parts)
|
||||
parts.extend(p for p in json_parts if p not in existing_parts)
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_response_binary_and_text(
|
||||
tool_response: list[ToolInvokeMessage],
|
||||
) -> Generator[ToolInvokeMessageBinary, None, None]:
|
||||
"""
|
||||
Extract tool response binary
|
||||
"""
|
||||
for response in tool_response:
|
||||
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
||||
mimetype = None
|
||||
if not response.meta:
|
||||
raise ValueError("missing meta data")
|
||||
if response.meta.get("mime_type"):
|
||||
mimetype = response.meta.get("mime_type")
|
||||
else:
|
||||
with contextlib.suppress(Exception):
|
||||
url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text)
|
||||
extension = url.suffix
|
||||
guess_type_result, _ = guess_type(f"a{extension}")
|
||||
if guess_type_result:
|
||||
mimetype = guess_type_result
|
||||
|
||||
if not mimetype:
|
||||
mimetype = "image/jpeg"
|
||||
|
||||
yield ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get("mime_type", mimetype),
|
||||
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
if not response.meta:
|
||||
raise ValueError("missing meta data")
|
||||
|
||||
yield ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get("mime_type", "application/octet-stream"),
|
||||
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
# check if there is a mime type in meta
|
||||
if response.meta and "mime_type" in response.meta:
|
||||
yield ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get("mime_type", "application/octet-stream")
|
||||
if response.meta
|
||||
else "application/octet-stream",
|
||||
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_message_files(
|
||||
tool_messages: Iterable[ToolInvokeMessageBinary],
|
||||
agent_message: Message,
|
||||
invoke_from: InvokeFrom,
|
||||
user_id: str,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Create message file
|
||||
|
||||
:return: message file ids
|
||||
"""
|
||||
result = []
|
||||
|
||||
for message in tool_messages:
|
||||
if "image" in message.mimetype:
|
||||
file_type = FileType.IMAGE
|
||||
elif "video" in message.mimetype:
|
||||
file_type = FileType.VIDEO
|
||||
elif "audio" in message.mimetype:
|
||||
file_type = FileType.AUDIO
|
||||
elif "text" in message.mimetype or "pdf" in message.mimetype:
|
||||
file_type = FileType.DOCUMENT
|
||||
else:
|
||||
file_type = FileType.CUSTOM
|
||||
|
||||
# extract tool file id from url
|
||||
tool_file_id = message.url.split("/")[-1].split(".")[0]
|
||||
message_file = MessageFile(
|
||||
message_id=agent_message.id,
|
||||
type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
belongs_to="assistant",
|
||||
url=message.url,
|
||||
upload_file_id=tool_file_id,
|
||||
created_by_role=(
|
||||
CreatorUserRole.ACCOUNT
|
||||
if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||
else CreatorUserRole.END_USER
|
||||
),
|
||||
created_by=user_id,
|
||||
)
|
||||
|
||||
db.session.add(message_file)
|
||||
db.session.commit()
|
||||
db.session.refresh(message_file)
|
||||
|
||||
result.append(message_file.id)
|
||||
|
||||
db.session.close()
|
||||
|
||||
return result
|
||||
253
dify/api/core/tools/tool_file_manager.py
Normal file
253
dify/api/core/tools/tool_file_manager.py
Normal file
@@ -0,0 +1,253 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from mimetypes import guess_extension, guess_type
|
||||
from typing import Union
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db as global_db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import MessageFile
|
||||
from models.tools import ToolFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
|
||||
class ToolFileManager:
|
||||
_engine: Engine
|
||||
|
||||
def __init__(self, engine: Engine | None = None):
|
||||
if engine is None:
|
||||
engine = global_db.engine
|
||||
self._engine = engine
|
||||
|
||||
@staticmethod
|
||||
def sign_file(tool_file_id: str, extension: str) -> str:
|
||||
"""
|
||||
sign file to get a temporary url for plugin access
|
||||
"""
|
||||
# Use internal URL for plugin/tool file access in Docker environments
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||
|
||||
@staticmethod
|
||||
def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
"""
|
||||
verify signature
|
||||
"""
|
||||
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||
|
||||
# verify signature
|
||||
if sign != recalculated_encoded_sign:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
|
||||
def create_file_by_raw(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: str | None,
|
||||
file_binary: bytes,
|
||||
mimetype: str,
|
||||
filename: str | None = None,
|
||||
) -> ToolFile:
|
||||
extension = guess_extension(mimetype) or ".bin"
|
||||
unique_name = uuid4().hex
|
||||
unique_filename = f"{unique_name}{extension}"
|
||||
# default just as before
|
||||
present_filename = unique_filename
|
||||
if filename is not None:
|
||||
has_extension = len(filename.split(".")) > 1
|
||||
# Add extension flexibly
|
||||
present_filename = filename if has_extension else f"{filename}{extension}"
|
||||
filepath = f"tools/{tenant_id}/{unique_filename}"
|
||||
storage.save(filepath, file_binary)
|
||||
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_key=filepath,
|
||||
mimetype=mimetype,
|
||||
name=present_filename,
|
||||
size=len(file_binary),
|
||||
original_url=None,
|
||||
)
|
||||
|
||||
session.add(tool_file)
|
||||
session.commit()
|
||||
session.refresh(tool_file)
|
||||
|
||||
return tool_file
|
||||
|
||||
def create_file_by_url(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
file_url: str,
|
||||
conversation_id: str | None = None,
|
||||
) -> ToolFile:
|
||||
# try to download image
|
||||
try:
|
||||
response = ssrf_proxy.get(file_url)
|
||||
response.raise_for_status()
|
||||
blob = response.content
|
||||
except httpx.TimeoutException:
|
||||
raise ValueError(f"timeout when downloading file from {file_url}")
|
||||
|
||||
mimetype = (
|
||||
guess_type(file_url)[0]
|
||||
or response.headers.get("Content-Type", "").split(";")[0].strip()
|
||||
or "application/octet-stream"
|
||||
)
|
||||
extension = guess_extension(mimetype) or ".bin"
|
||||
unique_name = uuid4().hex
|
||||
filename = f"{unique_name}{extension}"
|
||||
filepath = f"tools/{tenant_id}/{filename}"
|
||||
storage.save(filepath, blob)
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_key=filepath,
|
||||
mimetype=mimetype,
|
||||
original_url=file_url,
|
||||
name=filename,
|
||||
size=len(blob),
|
||||
)
|
||||
|
||||
session.add(tool_file)
|
||||
session.commit()
|
||||
|
||||
return tool_file
|
||||
|
||||
def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]:
|
||||
"""
|
||||
get file binary
|
||||
|
||||
:param id: the id of the file
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
ToolFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
|
||||
blob = storage.load_once(tool_file.file_key)
|
||||
|
||||
return blob, tool_file.mimetype
|
||||
|
||||
def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str], None]:
|
||||
"""
|
||||
get file binary
|
||||
|
||||
:param id: the id of the file
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
message_file: MessageFile | None = (
|
||||
session.query(MessageFile)
|
||||
.where(
|
||||
MessageFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Check if message_file is not None
|
||||
if message_file is not None:
|
||||
# get tool file id
|
||||
if message_file.url is not None:
|
||||
tool_file_id = message_file.url.split("/")[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split(".")[0]
|
||||
else:
|
||||
tool_file_id = None
|
||||
else:
|
||||
tool_file_id = None
|
||||
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
|
||||
blob = storage.load_once(tool_file.file_key)
|
||||
|
||||
return blob, tool_file.mimetype
|
||||
|
||||
def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]:
|
||||
"""
|
||||
get file binary
|
||||
|
||||
:param tool_file_id: the id of the tool file
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tool_file:
|
||||
return None, None
|
||||
|
||||
stream = storage.load_stream(tool_file.file_key)
|
||||
|
||||
return stream, tool_file
|
||||
|
||||
|
||||
# init tool_file_parser
|
||||
from core.file.tool_file_parser import set_tool_file_manager_factory
|
||||
|
||||
|
||||
def _factory() -> ToolFileManager:
|
||||
return ToolFileManager()
|
||||
|
||||
|
||||
set_tool_file_manager_factory(_factory)
|
||||
97
dify/api/core/tools/tool_label_manager.py
Normal file
97
dify/api/core/tools/tool_label_manager.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
from core.tools.entities.values import default_tool_label_name_list
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ToolLabelBinding
|
||||
|
||||
|
||||
class ToolLabelManager:
|
||||
@classmethod
|
||||
def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]:
|
||||
"""
|
||||
Filter tool labels
|
||||
"""
|
||||
tool_labels = [label for label in tool_labels if label in default_tool_label_name_list]
|
||||
return list(set(tool_labels))
|
||||
|
||||
@classmethod
|
||||
def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
|
||||
"""
|
||||
Update tool labels
|
||||
"""
|
||||
labels = cls.filter_tool_labels(labels)
|
||||
|
||||
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
provider_id = controller.provider_id
|
||||
else:
|
||||
raise ValueError("Unsupported tool type")
|
||||
|
||||
# delete old labels
|
||||
db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete()
|
||||
|
||||
# insert new labels
|
||||
for label in labels:
|
||||
db.session.add(
|
||||
ToolLabelBinding(
|
||||
tool_id=provider_id,
|
||||
tool_type=controller.provider_type.value,
|
||||
label_name=label,
|
||||
)
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
|
||||
"""
|
||||
Get tool labels
|
||||
"""
|
||||
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
provider_id = controller.provider_id
|
||||
elif isinstance(controller, BuiltinToolProviderController):
|
||||
return controller.tool_labels
|
||||
else:
|
||||
raise ValueError("Unsupported tool type")
|
||||
stmt = select(ToolLabelBinding.label_name).where(
|
||||
ToolLabelBinding.tool_id == provider_id,
|
||||
ToolLabelBinding.tool_type == controller.provider_type.value,
|
||||
)
|
||||
labels = db.session.scalars(stmt).all()
|
||||
|
||||
return list(labels)
|
||||
|
||||
@classmethod
|
||||
def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
|
||||
"""
|
||||
Get tools labels
|
||||
|
||||
:param tool_providers: list of tool providers
|
||||
|
||||
:return: dict of tool labels
|
||||
:key: tool id
|
||||
:value: list of tool labels
|
||||
"""
|
||||
if not tool_providers:
|
||||
return {}
|
||||
|
||||
for controller in tool_providers:
|
||||
if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
raise ValueError("Unsupported tool type")
|
||||
|
||||
provider_ids = []
|
||||
for controller in tool_providers:
|
||||
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
|
||||
provider_ids.append(controller.provider_id)
|
||||
|
||||
labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()
|
||||
|
||||
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
|
||||
|
||||
for label in labels:
|
||||
tool_labels[label.tool_id].append(label.label_name)
|
||||
|
||||
return tool_labels
|
||||
1046
dify/api/core/tools/tool_manager.py
Normal file
1046
dify/api/core/tools/tool_manager.py
Normal file
File diff suppressed because it is too large
Load Diff
0
dify/api/core/tools/utils/__init__.py
Normal file
0
dify/api/core/tools/utils/__init__.py
Normal file
158
dify/api/core/tools/utils/configuration.py
Normal file
158
dify/api/core/tools/utils/configuration.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import contextlib
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
|
||||
|
||||
class ToolParameterConfigurationManager:
|
||||
"""
|
||||
Tool parameter configuration manager
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
tool_runtime: Tool
|
||||
provider_name: str
|
||||
provider_type: ToolProviderType
|
||||
identity_id: str
|
||||
|
||||
def __init__(
|
||||
self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
|
||||
):
|
||||
self.tenant_id = tenant_id
|
||||
self.tool_runtime = tool_runtime
|
||||
self.provider_name = provider_name
|
||||
self.provider_type = provider_type
|
||||
self.identity_id = identity_id
|
||||
|
||||
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
deep copy parameters
|
||||
"""
|
||||
return deepcopy(parameters)
|
||||
|
||||
def _merge_parameters(self) -> list[ToolParameter]:
|
||||
"""
|
||||
merge parameters
|
||||
"""
|
||||
# get tool parameters
|
||||
tool_parameters = self.tool_runtime.entity.parameters or []
|
||||
# get tool runtime parameters
|
||||
runtime_parameters = self.tool_runtime.get_runtime_parameters()
|
||||
# override parameters
|
||||
current_parameters = tool_parameters.copy()
|
||||
for runtime_parameter in runtime_parameters:
|
||||
found = False
|
||||
for index, parameter in enumerate(current_parameters):
|
||||
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
|
||||
current_parameters[index] = runtime_parameter
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
current_parameters.append(runtime_parameter)
|
||||
|
||||
return current_parameters
|
||||
|
||||
def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask tool parameters
|
||||
|
||||
return a deep copy of parameters with masked values
|
||||
"""
|
||||
parameters = self._deep_copy(parameters)
|
||||
|
||||
# override parameters
|
||||
current_parameters = self._merge_parameters()
|
||||
|
||||
for parameter in current_parameters:
|
||||
if (
|
||||
parameter.form == ToolParameter.ToolParameterForm.FORM
|
||||
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
):
|
||||
if parameter.name in parameters:
|
||||
if len(parameters[parameter.name]) > 6:
|
||||
parameters[parameter.name] = (
|
||||
parameters[parameter.name][:2]
|
||||
+ "*" * (len(parameters[parameter.name]) - 4)
|
||||
+ parameters[parameter.name][-2:]
|
||||
)
|
||||
else:
|
||||
parameters[parameter.name] = "*" * len(parameters[parameter.name])
|
||||
|
||||
return parameters
|
||||
|
||||
def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
encrypt tool parameters with tenant id
|
||||
|
||||
return a deep copy of parameters with encrypted values
|
||||
"""
|
||||
# override parameters
|
||||
current_parameters = self._merge_parameters()
|
||||
|
||||
parameters = self._deep_copy(parameters)
|
||||
|
||||
for parameter in current_parameters:
|
||||
if (
|
||||
parameter.form == ToolParameter.ToolParameterForm.FORM
|
||||
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
):
|
||||
if parameter.name in parameters:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
|
||||
parameters[parameter.name] = encrypted
|
||||
|
||||
return parameters
|
||||
|
||||
def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
decrypt tool parameters with tenant id
|
||||
|
||||
return a deep copy of parameters with decrypted values
|
||||
"""
|
||||
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f"{self.provider_type.value}.{self.provider_name}",
|
||||
tool_name=self.tool_runtime.entity.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER,
|
||||
identity_id=self.identity_id,
|
||||
)
|
||||
cached_parameters = cache.get()
|
||||
if cached_parameters:
|
||||
return cached_parameters
|
||||
|
||||
# override parameters
|
||||
current_parameters = self._merge_parameters()
|
||||
has_secret_input = False
|
||||
|
||||
for parameter in current_parameters:
|
||||
if (
|
||||
parameter.form == ToolParameter.ToolParameterForm.FORM
|
||||
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
):
|
||||
if parameter.name in parameters:
|
||||
has_secret_input = True
|
||||
with contextlib.suppress(Exception):
|
||||
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
|
||||
|
||||
if has_secret_input:
|
||||
cache.set(parameters)
|
||||
|
||||
return parameters
|
||||
|
||||
def delete_tool_parameters_cache(self):
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f"{self.provider_type.value}.{self.provider_name}",
|
||||
tool_name=self.tool_runtime.entity.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER,
|
||||
identity_id=self.identity_id,
|
||||
)
|
||||
cache.delete()
|
||||
@@ -0,0 +1,200 @@
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.models.document import Document as RagDocument
|
||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
default_retrieval_model: dict[str, Any] = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
||||
class DatasetMultiRetrieverToolInput(BaseModel):
|
||||
query: str = Field(..., description="dataset multi retriever and rerank")
|
||||
|
||||
|
||||
class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
"""Tool for querying multi dataset."""
|
||||
|
||||
name: str = "dataset_"
|
||||
args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
|
||||
description: str = "dataset multi retriever and rerank. "
|
||||
dataset_ids: list[str]
|
||||
reranking_provider_name: str
|
||||
reranking_model_name: str
|
||||
|
||||
@classmethod
|
||||
def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
|
||||
return cls(
|
||||
name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs
|
||||
)
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
threads = []
|
||||
all_documents: list[RagDocument] = []
|
||||
for dataset_id in self.dataset_ids:
|
||||
retrieval_thread = threading.Thread(
|
||||
target=self._retriever,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"dataset_id": dataset_id,
|
||||
"query": query,
|
||||
"all_documents": all_documents,
|
||||
"hit_callbacks": self.hit_callbacks,
|
||||
},
|
||||
)
|
||||
threads.append(retrieval_thread)
|
||||
retrieval_thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
# do rerank for searched documents
|
||||
model_manager = ModelManager()
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=self.reranking_provider_name,
|
||||
model_type=ModelType.RERANK,
|
||||
model=self.reranking_model_name,
|
||||
)
|
||||
|
||||
rerank_runner = RerankModelRunner(rerank_model_instance)
|
||||
all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k)
|
||||
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.on_tool_end(all_documents)
|
||||
|
||||
document_score_list = {}
|
||||
for item in all_documents:
|
||||
if item.metadata and item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
|
||||
document_context_list = []
|
||||
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id.in_(self.dataset_ids),
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.index_node_id.in_(index_node_ids),
|
||||
)
|
||||
segments = db.session.scalars(document_segment_stmt).all()
|
||||
|
||||
if segments:
|
||||
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
|
||||
sorted_segments = sorted(
|
||||
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
|
||||
)
|
||||
for segment in sorted_segments:
|
||||
if segment.answer:
|
||||
document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
|
||||
else:
|
||||
document_context_list.append(segment.get_sign_content())
|
||||
if self.return_resource:
|
||||
context_list: list[RetrievalSourceMetadata] = []
|
||||
resource_number = 1
|
||||
for segment in sorted_segments:
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
document_stmt = select(Document).where(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
document = db.session.scalar(document_stmt)
|
||||
if dataset and document:
|
||||
source = RetrievalSourceMetadata(
|
||||
position=resource_number,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
document_id=document.id,
|
||||
document_name=document.name,
|
||||
data_source_type=document.data_source_type,
|
||||
segment_id=segment.id,
|
||||
retriever_from=self.retriever_from,
|
||||
score=document_score_list.get(segment.index_node_id),
|
||||
doc_metadata=document.doc_metadata,
|
||||
)
|
||||
|
||||
if self.retriever_from == "dev":
|
||||
source.hit_count = segment.hit_count
|
||||
source.word_count = segment.word_count
|
||||
source.segment_position = segment.position
|
||||
source.index_node_hash = segment.index_node_hash
|
||||
if segment.answer:
|
||||
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source.content = segment.content
|
||||
context_list.append(source)
|
||||
resource_number += 1
|
||||
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
|
||||
return str("\n".join(document_context_list))
|
||||
return ""
|
||||
|
||||
def _retriever(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
all_documents: list,
|
||||
hit_callbacks: list[DatasetIndexToolCallbackHandler],
|
||||
):
|
||||
with flask_app.app_context():
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(stmt)
|
||||
|
||||
if not dataset:
|
||||
return []
|
||||
|
||||
for hit_callback in hit_callbacks:
|
||||
hit_callback.on_query(query, dataset.id)
|
||||
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model = dataset.retrieval_model or default_retrieval_model
|
||||
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k") or 4,
|
||||
)
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
else:
|
||||
if self.top_k > 0:
|
||||
# retrieval source
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_model["search_method"],
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k") or 4,
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else 0.0,
|
||||
reranking_model=retrieval_model.get("reranking_model", None)
|
||||
if retrieval_model["reranking_enable"]
|
||||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||
weights=retrieval_model.get("weights", None),
|
||||
)
|
||||
|
||||
all_documents.extend(documents)
|
||||
@@ -0,0 +1,31 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
|
||||
|
||||
class DatasetRetrieverBaseTool(BaseModel, ABC):
|
||||
"""Tool for querying a Dataset."""
|
||||
|
||||
name: str = "dataset"
|
||||
description: str = "use this to retrieve a dataset. "
|
||||
tenant_id: str
|
||||
top_k: int = 4
|
||||
score_threshold: float | None = None
|
||||
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
|
||||
return_resource: bool
|
||||
retriever_from: str
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def run(self, query: str) -> str:
|
||||
"""Use the tool."""
|
||||
return self._run(query)
|
||||
|
||||
@abstractmethod
|
||||
def _run(self, query: str) -> str:
|
||||
"""Use the tool.
|
||||
|
||||
Add run_manager: Optional[CallbackManagerForToolRun] = None
|
||||
to child implementations to enable tracing,
|
||||
"""
|
||||
@@ -0,0 +1,234 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.models.document import Document as RetrievalDocument
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
default_retrieval_model: dict[str, Any] = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"reranking_mode": "reranking_model",
|
||||
"top_k": 2,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
||||
class DatasetRetrieverToolInput(BaseModel):
|
||||
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
|
||||
|
||||
|
||||
class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
"""Tool for querying a Dataset."""
|
||||
|
||||
name: str = "dataset"
|
||||
args_schema: type[BaseModel] = DatasetRetrieverToolInput
|
||||
description: str = "use this to retrieve a dataset. "
|
||||
dataset_id: str
|
||||
user_id: str | None = None
|
||||
retrieve_config: DatasetRetrieveConfigEntity
|
||||
inputs: dict
|
||||
|
||||
@classmethod
|
||||
def from_dataset(cls, dataset: Dataset, **kwargs):
|
||||
description = dataset.description
|
||||
if not description:
|
||||
description = "useful for when you want to answer queries about the " + dataset.name
|
||||
|
||||
description = description.replace("\n", "").replace("\r", "")
|
||||
return cls(
|
||||
name=f"dataset_{dataset.id.replace('-', '_')}",
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
description=description,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
dataset_stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id)
|
||||
dataset = db.session.scalar(dataset_stmt)
|
||||
|
||||
if not dataset:
|
||||
return ""
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.on_query(query, dataset.id)
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
|
||||
[dataset.id],
|
||||
query,
|
||||
self.tenant_id,
|
||||
self.user_id or "unknown",
|
||||
cast(str, self.retrieve_config.metadata_filtering_mode),
|
||||
cast(ModelConfig, self.retrieve_config.metadata_model_config),
|
||||
self.retrieve_config.metadata_filtering_conditions,
|
||||
self.inputs,
|
||||
)
|
||||
if metadata_filter_document_ids:
|
||||
document_ids_filter = metadata_filter_document_ids.get(dataset.id, [])
|
||||
else:
|
||||
document_ids_filter = None
|
||||
if dataset.provider == "external":
|
||||
results: list[RetrievalDocument] = []
|
||||
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
external_retrieval_parameters=dataset.retrieval_model,
|
||||
metadata_condition=metadata_condition,
|
||||
)
|
||||
for external_document in external_documents:
|
||||
document = RetrievalDocument(
|
||||
page_content=external_document.get("content"),
|
||||
metadata=external_document.get("metadata"),
|
||||
provider="external",
|
||||
)
|
||||
if document.metadata is not None:
|
||||
document.metadata["score"] = external_document.get("score")
|
||||
document.metadata["title"] = external_document.get("title")
|
||||
document.metadata["dataset_id"] = dataset.id
|
||||
document.metadata["dataset_name"] = dataset.name
|
||||
results.append(document)
|
||||
# deal with external documents
|
||||
context_list: list[RetrievalSourceMetadata] = []
|
||||
for position, item in enumerate(results, start=1):
|
||||
if item.metadata is not None:
|
||||
source = RetrievalSourceMetadata(
|
||||
position=position,
|
||||
dataset_id=item.metadata.get("dataset_id"),
|
||||
dataset_name=item.metadata.get("dataset_name"),
|
||||
document_id=item.metadata.get("document_id") or item.metadata.get("title"),
|
||||
document_name=item.metadata.get("title"),
|
||||
data_source_type="external",
|
||||
retriever_from=self.retriever_from,
|
||||
score=item.metadata.get("score"),
|
||||
title=item.metadata.get("title"),
|
||||
content=item.page_content,
|
||||
)
|
||||
context_list.append(source)
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
|
||||
return str("\n".join([item.page_content for item in results]))
|
||||
else:
|
||||
if metadata_condition and not document_ids_filter:
|
||||
return ""
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
|
||||
retrieval_resource_list: list[RetrievalSourceMetadata] = []
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=self.top_k,
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
return str("\n".join([document.page_content for document in documents]))
|
||||
else:
|
||||
if self.top_k > 0:
|
||||
# retrieval source
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=self.top_k,
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else 0.0,
|
||||
reranking_model=retrieval_model.get("reranking_model")
|
||||
if retrieval_model["reranking_enable"]
|
||||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||
weights=retrieval_model.get("weights"),
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
else:
|
||||
documents = []
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.on_tool_end(documents)
|
||||
document_score_list = {}
|
||||
if dataset.indexing_technique != "economy":
|
||||
for item in documents:
|
||||
if item.metadata is not None and item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
document_context_list: list[DocumentContext] = []
|
||||
records = RetrievalService.format_retrieval_documents(documents)
|
||||
if records:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
if segment.answer:
|
||||
document_context_list.append(
|
||||
DocumentContext(
|
||||
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
|
||||
score=record.score,
|
||||
)
|
||||
)
|
||||
else:
|
||||
document_context_list.append(
|
||||
DocumentContext(
|
||||
content=segment.get_sign_content(),
|
||||
score=record.score,
|
||||
)
|
||||
)
|
||||
|
||||
if self.return_resource:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
dataset_document_stmt = select(DatasetDocument).where(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
document = db.session.scalar(dataset_document_stmt)
|
||||
if dataset and document:
|
||||
source = RetrievalSourceMetadata(
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
document_id=document.id,
|
||||
document_name=document.name,
|
||||
data_source_type=document.data_source_type,
|
||||
segment_id=segment.id,
|
||||
retriever_from=self.retriever_from,
|
||||
score=record.score or 0.0,
|
||||
doc_metadata=document.doc_metadata,
|
||||
)
|
||||
|
||||
if self.retriever_from == "dev":
|
||||
source.hit_count = segment.hit_count
|
||||
source.word_count = segment.word_count
|
||||
source.segment_position = segment.position
|
||||
source.index_node_hash = segment.index_node_hash
|
||||
if segment.answer:
|
||||
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source.content = segment.content
|
||||
retrieval_resource_list.append(source)
|
||||
|
||||
if self.return_resource and retrieval_resource_list:
|
||||
retrieval_resource_list = sorted(
|
||||
retrieval_resource_list,
|
||||
key=lambda x: x.score or 0.0,
|
||||
reverse=True,
|
||||
)
|
||||
for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore
|
||||
item.position = position # type: ignore
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
||||
if document_context_list:
|
||||
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
|
||||
return str("\n".join([document_context.content for document_context in document_context_list]))
|
||||
return ""
|
||||
136
dify/api/core/tools/utils/dataset_retriever_tool.py
Normal file
136
dify/api/core/tools/utils/dataset_retriever_tool.py
Normal file
@@ -0,0 +1,136 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolEntity,
|
||||
ToolIdentity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
|
||||
|
||||
class DatasetRetrieverTool(Tool):
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool):
|
||||
super().__init__(entity, runtime)
|
||||
self.retrieval_tool = retrieval_tool
|
||||
|
||||
@staticmethod
|
||||
def get_dataset_tools(
|
||||
tenant_id: str,
|
||||
dataset_ids: list[str],
|
||||
retrieve_config: DatasetRetrieveConfigEntity | None,
|
||||
return_resource: bool,
|
||||
invoke_from: InvokeFrom,
|
||||
hit_callback: DatasetIndexToolCallbackHandler,
|
||||
user_id: str,
|
||||
inputs: dict,
|
||||
) -> list["DatasetRetrieverTool"]:
|
||||
"""
|
||||
get dataset tool
|
||||
"""
|
||||
# check if retrieve_config is valid
|
||||
if dataset_ids is None or len(dataset_ids) == 0:
|
||||
return []
|
||||
if retrieve_config is None:
|
||||
return []
|
||||
|
||||
feature = DatasetRetrieval()
|
||||
|
||||
# save original retrieve strategy, and set retrieve strategy to SINGLE
|
||||
# Agent only support SINGLE mode
|
||||
original_retriever_mode = retrieve_config.retrieve_strategy
|
||||
retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
|
||||
retrieval_tools = feature.to_dataset_retriever_tool(
|
||||
tenant_id=tenant_id,
|
||||
dataset_ids=dataset_ids,
|
||||
retrieve_config=retrieve_config,
|
||||
return_resource=return_resource,
|
||||
invoke_from=invoke_from,
|
||||
hit_callback=hit_callback,
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
)
|
||||
if retrieval_tools is None or len(retrieval_tools) == 0:
|
||||
return []
|
||||
|
||||
# restore retrieve strategy
|
||||
retrieve_config.retrieve_strategy = original_retriever_mode
|
||||
|
||||
# convert retrieval tools to Tools
|
||||
tools = []
|
||||
for retrieval_tool in retrieval_tools:
|
||||
tool = DatasetRetrieverTool(
|
||||
retrieval_tool=retrieval_tool,
|
||||
entity=ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="")
|
||||
),
|
||||
parameters=[],
|
||||
description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description),
|
||||
),
|
||||
runtime=ToolRuntime(tenant_id=tenant_id),
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
|
||||
return tools
|
||||
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> list[ToolParameter]:
|
||||
return [
|
||||
ToolParameter(
|
||||
name="query",
|
||||
label=I18nObject(en_US="", zh_Hans=""),
|
||||
human_description=I18nObject(en_US="", zh_Hans=""),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description="Query for the dataset to be used to retrieve the dataset.",
|
||||
required=True,
|
||||
default="",
|
||||
placeholder=I18nObject(en_US="", zh_Hans=""),
|
||||
),
|
||||
]
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.DATASET_RETRIEVAL
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
invoke dataset retriever tool
|
||||
"""
|
||||
query = tool_parameters.get("query")
|
||||
if not query:
|
||||
yield self.create_text_message(text="please input query")
|
||||
else:
|
||||
# invoke dataset retriever tool
|
||||
result = self.retrieval_tool.run(query=query)
|
||||
yield self.create_text_message(text=result)
|
||||
|
||||
def validate_credentials(
|
||||
self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
|
||||
) -> str | None:
|
||||
"""
|
||||
validate the credentials for dataset retriever tool
|
||||
"""
|
||||
pass
|
||||
32
dify/api/core/tools/utils/encryption.py
Normal file
32
dify/api/core/tools/utils/encryption.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# Import generic components from provider_encryption module
|
||||
from core.helper.provider_encryption import (
|
||||
ProviderConfigCache,
|
||||
ProviderConfigEncrypter,
|
||||
create_provider_encrypter,
|
||||
)
|
||||
|
||||
# Re-export for backward compatibility
|
||||
__all__ = [
|
||||
"ProviderConfigCache",
|
||||
"ProviderConfigEncrypter",
|
||||
"create_provider_encrypter",
|
||||
"create_tool_provider_encrypter",
|
||||
]
|
||||
|
||||
# Tool-specific imports
|
||||
from core.helper.provider_cache import SingletonProviderCredentialsCache
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
|
||||
|
||||
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
|
||||
cache = SingletonProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
provider_type=controller.provider_type.value,
|
||||
provider_identity=controller.entity.identity.name,
|
||||
)
|
||||
encrypt = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
|
||||
provider_config_cache=cache,
|
||||
)
|
||||
return encrypt, cache
|
||||
168
dify/api/core/tools/utils/message_transformer.py
Normal file
168
dify/api/core/tools/utils/message_transformer.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from mimetypes import guess_extension
|
||||
from uuid import UUID
|
||||
|
||||
import numpy as np
|
||||
import pytz
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def safe_json_value(v):
|
||||
if isinstance(v, datetime):
|
||||
tz_name = "UTC"
|
||||
if isinstance(current_user, Account) and current_user.timezone is not None:
|
||||
tz_name = current_user.timezone
|
||||
return v.astimezone(pytz.timezone(tz_name)).isoformat()
|
||||
elif isinstance(v, date):
|
||||
return v.isoformat()
|
||||
elif isinstance(v, UUID):
|
||||
return str(v)
|
||||
elif isinstance(v, Decimal):
|
||||
return float(v)
|
||||
elif isinstance(v, bytes):
|
||||
try:
|
||||
return v.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return v.hex()
|
||||
elif isinstance(v, memoryview):
|
||||
return v.tobytes().hex()
|
||||
elif isinstance(v, np.ndarray):
|
||||
return v.tolist()
|
||||
elif isinstance(v, dict):
|
||||
return safe_json_dict(v)
|
||||
elif isinstance(v, list | tuple | set):
|
||||
return [safe_json_value(i) for i in v]
|
||||
else:
|
||||
return v
|
||||
|
||||
|
||||
def safe_json_dict(d: dict):
|
||||
if not isinstance(d, dict):
|
||||
raise TypeError("safe_json_dict() expects a dictionary (dict) as input")
|
||||
return {k: safe_json_value(v) for k, v in d.items()}
|
||||
|
||||
|
||||
class ToolFileMessageTransformer:
|
||||
@classmethod
|
||||
def transform_tool_invoke_messages(
|
||||
cls,
|
||||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
Transform tool message and handle file download
|
||||
"""
|
||||
for message in messages:
|
||||
if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
|
||||
yield message
|
||||
elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance(
|
||||
message.message, ToolInvokeMessage.TextMessage
|
||||
):
|
||||
# try to download image
|
||||
try:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
tool_file_manager = ToolFileManager()
|
||||
tool_file = tool_file_manager.create_file_by_url(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
file_url=message.message.text,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
url = f"/files/tools/{tool_file.id}{guess_extension(tool_file.mimetype) or '.png'}"
|
||||
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=ToolInvokeMessage.TextMessage(text=url),
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
except Exception as e:
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.TEXT,
|
||||
message=ToolInvokeMessage.TextMessage(
|
||||
text=f"Failed to download image: {message.message.text}: {e}"
|
||||
),
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get mime type and save blob to storage
|
||||
meta = message.meta or {}
|
||||
|
||||
mimetype = meta.get("mime_type", "application/octet-stream")
|
||||
# get filename from meta
|
||||
filename = meta.get("filename", None)
|
||||
# if message is str, encode it to bytes
|
||||
|
||||
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
|
||||
raise ValueError("unexpected message type")
|
||||
|
||||
assert isinstance(message.message.blob, bytes)
|
||||
tool_file_manager = ToolFileManager()
|
||||
tool_file = tool_file_manager.create_file_by_raw(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_binary=message.message.blob,
|
||||
mimetype=mimetype,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
url = cls.get_tool_file_url(tool_file_id=tool_file.id, extension=guess_extension(tool_file.mimetype))
|
||||
|
||||
# check if file is image
|
||||
if "image" in mimetype:
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=ToolInvokeMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BINARY_LINK,
|
||||
message=ToolInvokeMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||
meta = message.meta or {}
|
||||
file = meta.get("file", None)
|
||||
if isinstance(file, File):
|
||||
if file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
assert file.related_id is not None
|
||||
url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
|
||||
if file.type == FileType.IMAGE:
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=ToolInvokeMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=ToolInvokeMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
yield message
|
||||
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
if isinstance(message.message, ToolInvokeMessage.JsonMessage):
|
||||
message.message.json_object = safe_json_value(message.message.json_object)
|
||||
yield message
|
||||
else:
|
||||
yield message
|
||||
|
||||
@classmethod
|
||||
def get_tool_file_url(cls, tool_file_id: str, extension: str | None) -> str:
|
||||
return f"/files/tools/{tool_file_id}{extension or '.bin'}"
|
||||
167
dify/api/core/tools/utils/model_invocation_utils.py
Normal file
167
dify/api/core/tools/utils/model_invocation_utils.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc.
|
||||
|
||||
Therefore, a model manager is needed to list/invoke/validate models.
|
||||
"""
|
||||
|
||||
import json
|
||||
from decimal import Decimal
|
||||
from typing import cast
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ToolModelInvoke
|
||||
|
||||
|
||||
class InvokeModelError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ModelInvocationUtils:
|
||||
@staticmethod
|
||||
def get_max_llm_context_tokens(
|
||||
tenant_id: str,
|
||||
) -> int:
|
||||
"""
|
||||
get max llm context tokens of the model
|
||||
"""
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
|
||||
if not model_instance:
|
||||
raise InvokeModelError("Model not found")
|
||||
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
|
||||
if not schema:
|
||||
raise InvokeModelError("No model schema found")
|
||||
|
||||
max_tokens: int | None = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
|
||||
if max_tokens is None:
|
||||
return 2048
|
||||
|
||||
return max_tokens
|
||||
|
||||
@staticmethod
|
||||
def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
|
||||
"""
|
||||
calculate tokens from prompt messages and model parameters
|
||||
"""
|
||||
|
||||
# get model instance
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM)
|
||||
|
||||
if not model_instance:
|
||||
raise InvokeModelError("Model not found")
|
||||
|
||||
# get tokens
|
||||
tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
return tokens
|
||||
|
||||
@staticmethod
|
||||
def invoke(
|
||||
user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage]
|
||||
) -> LLMResult:
|
||||
"""
|
||||
invoke model with parameters in user's own context
|
||||
|
||||
:param user_id: user id
|
||||
:param tenant_id: tenant id, the tenant id of the creator of the tool
|
||||
:param tool_type: tool type
|
||||
:param tool_name: tool name
|
||||
:param prompt_messages: prompt messages
|
||||
:return: AssistantPromptMessage
|
||||
"""
|
||||
|
||||
# get model manager
|
||||
model_manager = ModelManager()
|
||||
# get model instance
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
|
||||
# get prompt tokens
|
||||
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
model_parameters = {
|
||||
"temperature": 0.8,
|
||||
"top_p": 0.8,
|
||||
}
|
||||
|
||||
# create tool model invoke
|
||||
tool_model_invoke = ToolModelInvoke(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
provider=model_instance.provider,
|
||||
tool_type=tool_type,
|
||||
tool_name=tool_name,
|
||||
model_parameters=json.dumps(model_parameters),
|
||||
prompt_messages=json.dumps(jsonable_encoder(prompt_messages)),
|
||||
model_response="",
|
||||
prompt_tokens=prompt_tokens,
|
||||
answer_tokens=0,
|
||||
answer_unit_price=Decimal(),
|
||||
answer_price_unit=Decimal(),
|
||||
provider_response_latency=0,
|
||||
total_price=Decimal(),
|
||||
currency="USD",
|
||||
)
|
||||
|
||||
db.session.add(tool_model_invoke)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=[],
|
||||
stop=[],
|
||||
stream=False,
|
||||
user=user_id,
|
||||
callbacks=[],
|
||||
)
|
||||
except InvokeRateLimitError as e:
|
||||
raise InvokeModelError(f"Invoke rate limit error: {e}")
|
||||
except InvokeBadRequestError as e:
|
||||
raise InvokeModelError(f"Invoke bad request error: {e}")
|
||||
except InvokeConnectionError as e:
|
||||
raise InvokeModelError(f"Invoke connection error: {e}")
|
||||
except InvokeAuthorizationError as e:
|
||||
raise InvokeModelError("Invoke authorization error")
|
||||
except InvokeServerUnavailableError as e:
|
||||
raise InvokeModelError(f"Invoke server unavailable error: {e}")
|
||||
except Exception as e:
|
||||
raise InvokeModelError(f"Invoke error: {e}")
|
||||
|
||||
# update tool model invoke
|
||||
tool_model_invoke.model_response = str(response.message.content)
|
||||
if response.usage:
|
||||
tool_model_invoke.answer_tokens = response.usage.completion_tokens
|
||||
tool_model_invoke.answer_unit_price = response.usage.completion_unit_price
|
||||
tool_model_invoke.answer_price_unit = response.usage.completion_price_unit
|
||||
tool_model_invoke.provider_response_latency = response.usage.latency
|
||||
tool_model_invoke.total_price = response.usage.total_price
|
||||
tool_model_invoke.currency = response.usage.currency
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return response
|
||||
453
dify/api/core/tools/utils/parser.py
Normal file
453
dify/api/core/tools/utils/parser.py
Normal file
@@ -0,0 +1,453 @@
|
||||
import re
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from flask import request
|
||||
from yaml import YAMLError, safe_load
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
|
||||
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
|
||||
|
||||
class ApiBasedToolSchemaParser:
|
||||
@staticmethod
|
||||
def parse_openapi_to_tool_bundle(
|
||||
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
# set description to extra_info
|
||||
extra_info["description"] = openapi["info"].get("description", "")
|
||||
|
||||
if len(openapi["servers"]) == 0:
|
||||
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
|
||||
|
||||
server_url = openapi["servers"][0]["url"]
|
||||
request_env = request.headers.get("X-Request-Env")
|
||||
if request_env:
|
||||
matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env]
|
||||
server_url = matched_servers[0] if matched_servers else server_url
|
||||
|
||||
# list all interfaces
|
||||
interfaces = []
|
||||
for path, path_item in openapi["paths"].items():
|
||||
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
|
||||
for method in methods:
|
||||
if method in path_item:
|
||||
interfaces.append(
|
||||
{
|
||||
"path": path,
|
||||
"method": method,
|
||||
"operation": path_item[method],
|
||||
}
|
||||
)
|
||||
|
||||
# get all parameters
|
||||
bundles = []
|
||||
for interface in interfaces:
|
||||
# convert parameters
|
||||
parameters = []
|
||||
if "parameters" in interface["operation"]:
|
||||
for i, parameter in enumerate(interface["operation"]["parameters"]):
|
||||
if "$ref" in parameter:
|
||||
root = openapi
|
||||
reference = parameter["$ref"].split("/")[1:]
|
||||
for ref in reference:
|
||||
root = root[ref]
|
||||
interface["operation"]["parameters"][i] = root
|
||||
for parameter in interface["operation"]["parameters"]:
|
||||
# Handle complex type defaults that are not supported by PluginParameter
|
||||
default_value = None
|
||||
if "schema" in parameter and "default" in parameter["schema"]:
|
||||
default_value = ApiBasedToolSchemaParser._sanitize_default_value(parameter["schema"]["default"])
|
||||
|
||||
tool_parameter = ToolParameter(
|
||||
name=parameter["name"],
|
||||
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
|
||||
human_description=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=parameter.get("required", False),
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=parameter.get("description"),
|
||||
default=default_value,
|
||||
placeholder=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
|
||||
if typ:
|
||||
tool_parameter.type = typ
|
||||
|
||||
parameters.append(tool_parameter)
|
||||
# create tool bundle
|
||||
# check if there is a request body
|
||||
if "requestBody" in interface["operation"]:
|
||||
request_body = interface["operation"]["requestBody"]
|
||||
if "content" in request_body:
|
||||
for content_type, content in request_body["content"].items():
|
||||
# if there is a reference, get the reference and overwrite the content
|
||||
if "schema" not in content:
|
||||
continue
|
||||
|
||||
if "$ref" in content["schema"]:
|
||||
# get the reference
|
||||
root = openapi
|
||||
reference = content["schema"]["$ref"].split("/")[1:]
|
||||
for ref in reference:
|
||||
root = root[ref]
|
||||
# overwrite the content
|
||||
interface["operation"]["requestBody"]["content"][content_type]["schema"] = root
|
||||
|
||||
# handle allOf reference in schema properties
|
||||
for prop_dict in root.get("properties", {}).values():
|
||||
for item in prop_dict.get("allOf", []):
|
||||
if "$ref" in item:
|
||||
ref_schema = openapi
|
||||
reference = item["$ref"].split("/")[1:]
|
||||
for ref in reference:
|
||||
ref_schema = ref_schema[ref]
|
||||
else:
|
||||
ref_schema = item
|
||||
for key, value in ref_schema.items():
|
||||
if isinstance(value, list):
|
||||
if key not in prop_dict:
|
||||
prop_dict[key] = []
|
||||
# extends list field
|
||||
if isinstance(prop_dict[key], list):
|
||||
prop_dict[key].extend(value)
|
||||
elif key not in prop_dict:
|
||||
# add new field
|
||||
prop_dict[key] = value
|
||||
if "allOf" in prop_dict:
|
||||
del prop_dict["allOf"]
|
||||
|
||||
# parse body parameters
|
||||
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
|
||||
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
|
||||
required = body_schema.get("required", [])
|
||||
properties = body_schema.get("properties", {})
|
||||
for name, property in properties.items():
|
||||
# Handle complex type defaults that are not supported by PluginParameter
|
||||
default_value = ApiBasedToolSchemaParser._sanitize_default_value(
|
||||
property.get("default", None)
|
||||
)
|
||||
|
||||
tool = ToolParameter(
|
||||
name=name,
|
||||
label=I18nObject(en_US=name, zh_Hans=name),
|
||||
human_description=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=name in required,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=property.get("description", ""),
|
||||
default=default_value,
|
||||
placeholder=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
)
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
|
||||
if typ:
|
||||
tool.type = typ
|
||||
|
||||
parameters.append(tool)
|
||||
|
||||
# check if parameters is duplicated
|
||||
parameters_count = {}
|
||||
for parameter in parameters:
|
||||
if parameter.name not in parameters_count:
|
||||
parameters_count[parameter.name] = 0
|
||||
parameters_count[parameter.name] += 1
|
||||
for name, count in parameters_count.items():
|
||||
if count > 1:
|
||||
warning["duplicated_parameter"] = f"Parameter {name} is duplicated."
|
||||
|
||||
# check if there is a operation id, use $path_$method as operation id if not
|
||||
if "operationId" not in interface["operation"]:
|
||||
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||
path = interface["path"]
|
||||
if interface["path"].startswith("/"):
|
||||
path = interface["path"][1:]
|
||||
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||
path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
|
||||
if not path:
|
||||
path = "<root>"
|
||||
|
||||
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
|
||||
|
||||
bundles.append(
|
||||
ApiToolBundle(
|
||||
server_url=server_url + interface["path"],
|
||||
method=interface["method"],
|
||||
summary=interface["operation"]["description"]
|
||||
if "description" in interface["operation"]
|
||||
else interface["operation"].get("summary", None),
|
||||
operation_id=interface["operation"]["operationId"],
|
||||
parameters=parameters,
|
||||
author="",
|
||||
icon=None,
|
||||
openapi=interface["operation"],
|
||||
)
|
||||
)
|
||||
|
||||
return bundles
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_default_value(value):
|
||||
"""
|
||||
Sanitize default values for PluginParameter compatibility.
|
||||
Complex types (list, dict) are converted to None to avoid validation errors.
|
||||
|
||||
Args:
|
||||
value: The default value from OpenAPI schema
|
||||
|
||||
Returns:
|
||||
None for complex types (list, dict), otherwise the original value
|
||||
"""
|
||||
if isinstance(value, (list, dict)):
|
||||
return None
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
|
||||
parameter = parameter or {}
|
||||
typ: str | None = None
|
||||
if parameter.get("format") == "binary":
|
||||
return ToolParameter.ToolParameterType.FILE
|
||||
|
||||
if "type" in parameter:
|
||||
typ = parameter["type"]
|
||||
elif "schema" in parameter and "type" in parameter["schema"]:
|
||||
typ = parameter["schema"]["type"]
|
||||
|
||||
if typ in {"integer", "number"}:
|
||||
return ToolParameter.ToolParameterType.NUMBER
|
||||
elif typ == "boolean":
|
||||
return ToolParameter.ToolParameterType.BOOLEAN
|
||||
elif typ == "string":
|
||||
return ToolParameter.ToolParameterType.STRING
|
||||
elif typ == "array":
|
||||
items = parameter.get("items") or parameter.get("schema", {}).get("items")
|
||||
if items and items.get("format") == "binary":
|
||||
return ToolParameter.ToolParameterType.FILES
|
||||
else:
|
||||
# For regular arrays, return ARRAY type instead of None
|
||||
return ToolParameter.ToolParameterType.ARRAY
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_openapi_yaml_to_tool_bundle(
|
||||
yaml: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi yaml to tool bundle
|
||||
|
||||
:param yaml: the yaml string
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: the tool bundle
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
openapi: dict = safe_load(yaml)
|
||||
if openapi is None:
|
||||
raise ToolApiSchemaError("Invalid openapi yaml.")
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
|
||||
|
||||
@staticmethod
|
||||
def parse_swagger_to_openapi(
|
||||
swagger: dict, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> dict[str, Any]:
|
||||
warning = warning or {}
|
||||
"""
|
||||
parse swagger to openapi
|
||||
|
||||
:param swagger: the swagger dict
|
||||
:return: the openapi dict
|
||||
"""
|
||||
# convert swagger to openapi
|
||||
info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})
|
||||
|
||||
servers = swagger.get("servers", [])
|
||||
|
||||
if len(servers) == 0:
|
||||
raise ToolApiSchemaError("No server found in the swagger yaml.")
|
||||
|
||||
converted_openapi: dict[str, Any] = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": info.get("title", "Swagger"),
|
||||
"description": info.get("description", "Swagger"),
|
||||
"version": info.get("version", "1.0.0"),
|
||||
},
|
||||
"servers": swagger["servers"],
|
||||
"paths": {},
|
||||
"components": {"schemas": {}},
|
||||
}
|
||||
|
||||
# check paths
|
||||
if "paths" not in swagger or len(swagger["paths"]) == 0:
|
||||
raise ToolApiSchemaError("No paths found in the swagger yaml.")
|
||||
|
||||
# convert paths
|
||||
for path, path_item in swagger["paths"].items():
|
||||
converted_openapi["paths"][path] = {}
|
||||
for method, operation in path_item.items():
|
||||
if "operationId" not in operation:
|
||||
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
|
||||
|
||||
if ("summary" not in operation or len(operation["summary"]) == 0) and (
|
||||
"description" not in operation or len(operation["description"]) == 0
|
||||
):
|
||||
if warning is not None:
|
||||
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
|
||||
|
||||
converted_openapi["paths"][path][method] = {
|
||||
"operationId": operation["operationId"],
|
||||
"summary": operation.get("summary", ""),
|
||||
"description": operation.get("description", ""),
|
||||
"parameters": operation.get("parameters", []),
|
||||
"responses": operation.get("responses", {}),
|
||||
}
|
||||
|
||||
if "requestBody" in operation:
|
||||
converted_openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
|
||||
|
||||
# convert definitions
|
||||
if "definitions" in swagger:
|
||||
for name, definition in swagger["definitions"].items():
|
||||
converted_openapi["components"]["schemas"][name] = definition
|
||||
|
||||
return converted_openapi
|
||||
|
||||
@staticmethod
|
||||
def parse_openai_plugin_json_to_tool_bundle(
|
||||
json: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi plugin yaml to tool bundle
|
||||
|
||||
:param json: the json string
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: the tool bundle
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
try:
|
||||
openai_plugin = json_loads(json)
|
||||
api = openai_plugin["api"]
|
||||
api_url = api["url"]
|
||||
api_type = api["type"]
|
||||
except JSONDecodeError:
|
||||
raise ToolProviderNotFoundError("Invalid openai plugin json.")
|
||||
|
||||
if api_type != "openapi":
|
||||
raise ToolNotSupportedError("Only openapi is supported now.")
|
||||
|
||||
# get openapi yaml
|
||||
response = httpx.get(
|
||||
api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5
|
||||
)
|
||||
|
||||
try:
|
||||
if response.status_code != 200:
|
||||
raise ToolProviderNotFoundError("cannot get openapi yaml from url.")
|
||||
|
||||
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
|
||||
response.text, extra_info=extra_info, warning=warning
|
||||
)
|
||||
finally:
|
||||
response.close()
|
||||
|
||||
@staticmethod
|
||||
def auto_parse_to_tool_bundle(
|
||||
content: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> tuple[list[ApiToolBundle], str]:
|
||||
"""
|
||||
auto parse to tool bundle
|
||||
|
||||
:param content: the content
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: tools bundle, schema_type
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
content = content.strip()
|
||||
loaded_content = None
|
||||
json_error = None
|
||||
yaml_error = None
|
||||
|
||||
try:
|
||||
loaded_content = json_loads(content)
|
||||
except JSONDecodeError as e:
|
||||
json_error = e
|
||||
|
||||
if loaded_content is None:
|
||||
try:
|
||||
loaded_content = safe_load(content)
|
||||
except YAMLError as e:
|
||||
yaml_error = e
|
||||
if loaded_content is None:
|
||||
raise ToolApiSchemaError(
|
||||
f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)},"
|
||||
f" yaml error: {str(yaml_error)}"
|
||||
)
|
||||
|
||||
swagger_error = None
|
||||
openapi_error = None
|
||||
openapi_plugin_error = None
|
||||
schema_type = None
|
||||
|
||||
try:
|
||||
openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
|
||||
loaded_content, extra_info=extra_info, warning=warning
|
||||
)
|
||||
schema_type = ApiProviderSchemaType.OPENAPI
|
||||
return openapi, schema_type
|
||||
except ToolApiSchemaError as e:
|
||||
openapi_error = e
|
||||
|
||||
# openai parse error, fallback to swagger
|
||||
try:
|
||||
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
|
||||
loaded_content, extra_info=extra_info, warning=warning
|
||||
)
|
||||
schema_type = ApiProviderSchemaType.SWAGGER
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
|
||||
converted_swagger, extra_info=extra_info, warning=warning
|
||||
), schema_type
|
||||
except ToolApiSchemaError as e:
|
||||
swagger_error = e
|
||||
|
||||
# swagger parse error, fallback to openai plugin
|
||||
try:
|
||||
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
|
||||
json_dumps(loaded_content), extra_info=extra_info, warning=warning
|
||||
)
|
||||
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN
|
||||
except ToolNotSupportedError as e:
|
||||
# maybe it's not plugin at all
|
||||
openapi_plugin_error = e
|
||||
|
||||
raise ToolApiSchemaError(
|
||||
f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)},"
|
||||
f" openapi plugin error: {str(openapi_plugin_error)}"
|
||||
)
|
||||
187
dify/api/core/tools/utils/system_oauth_encryption.py
Normal file
187
dify/api/core/tools/utils/system_oauth_encryption.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Random import get_random_bytes
|
||||
from Crypto.Util.Padding import pad, unpad
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthEncryptionError(Exception):
|
||||
"""OAuth encryption/decryption specific error"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SystemOAuthEncrypter:
|
||||
"""
|
||||
A simple OAuth parameters encrypter using AES-CBC encryption.
|
||||
|
||||
This class provides methods to encrypt and decrypt OAuth parameters
|
||||
using AES-CBC mode with a key derived from the application's SECRET_KEY.
|
||||
"""
|
||||
|
||||
def __init__(self, secret_key: str | None = None):
|
||||
"""
|
||||
Initialize the OAuth encrypter.
|
||||
|
||||
Args:
|
||||
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
|
||||
|
||||
Raises:
|
||||
ValueError: If SECRET_KEY is not configured or empty
|
||||
"""
|
||||
secret_key = secret_key or dify_config.SECRET_KEY or ""
|
||||
|
||||
# Generate a fixed 256-bit key using SHA-256
|
||||
self.key = hashlib.sha256(secret_key.encode()).digest()
|
||||
|
||||
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Encrypt OAuth parameters.
|
||||
|
||||
Args:
|
||||
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
|
||||
|
||||
Returns:
|
||||
Base64-encoded encrypted string
|
||||
|
||||
Raises:
|
||||
OAuthEncryptionError: If encryption fails
|
||||
ValueError: If oauth_params is invalid
|
||||
"""
|
||||
|
||||
try:
|
||||
# Generate random IV (16 bytes)
|
||||
iv = get_random_bytes(16)
|
||||
|
||||
# Create AES cipher (CBC mode)
|
||||
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||
|
||||
# Encrypt data
|
||||
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
|
||||
encrypted_data = cipher.encrypt(padded_data)
|
||||
|
||||
# Combine IV and encrypted data
|
||||
combined = iv + encrypted_data
|
||||
|
||||
# Return base64 encoded string
|
||||
return base64.b64encode(combined).decode()
|
||||
|
||||
except Exception as e:
|
||||
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
|
||||
|
||||
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
|
||||
"""
|
||||
Decrypt OAuth parameters.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64-encoded encrypted string
|
||||
|
||||
Returns:
|
||||
Decrypted OAuth parameters dictionary
|
||||
|
||||
Raises:
|
||||
OAuthEncryptionError: If decryption fails
|
||||
ValueError: If encrypted_data is invalid
|
||||
"""
|
||||
if not isinstance(encrypted_data, str):
|
||||
raise ValueError("encrypted_data must be a string")
|
||||
|
||||
if not encrypted_data:
|
||||
raise ValueError("encrypted_data cannot be empty")
|
||||
|
||||
try:
|
||||
# Base64 decode
|
||||
combined = base64.b64decode(encrypted_data)
|
||||
|
||||
# Check minimum length (IV + at least one AES block)
|
||||
if len(combined) < 32: # 16 bytes IV + 16 bytes minimum encrypted data
|
||||
raise ValueError("Invalid encrypted data format")
|
||||
|
||||
# Separate IV and encrypted data
|
||||
iv = combined[:16]
|
||||
encrypted_data_bytes = combined[16:]
|
||||
|
||||
# Create AES cipher
|
||||
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||
|
||||
# Decrypt data
|
||||
decrypted_data = cipher.decrypt(encrypted_data_bytes)
|
||||
unpadded_data = unpad(decrypted_data, AES.block_size)
|
||||
|
||||
# Parse JSON
|
||||
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
|
||||
|
||||
if not isinstance(oauth_params, dict):
|
||||
raise ValueError("Decrypted data is not a valid dictionary")
|
||||
|
||||
return oauth_params
|
||||
|
||||
except Exception as e:
|
||||
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
|
||||
|
||||
|
||||
# Factory function for creating encrypter instances
|
||||
def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter:
|
||||
"""
|
||||
Create an OAuth encrypter instance.
|
||||
|
||||
Args:
|
||||
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
|
||||
|
||||
Returns:
|
||||
SystemOAuthEncrypter instance
|
||||
"""
|
||||
return SystemOAuthEncrypter(secret_key=secret_key)
|
||||
|
||||
|
||||
# Global encrypter instance (for backward compatibility)
|
||||
_oauth_encrypter: SystemOAuthEncrypter | None = None
|
||||
|
||||
|
||||
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
|
||||
"""
|
||||
Get the global OAuth encrypter instance.
|
||||
|
||||
Returns:
|
||||
SystemOAuthEncrypter instance
|
||||
"""
|
||||
global _oauth_encrypter
|
||||
if _oauth_encrypter is None:
|
||||
_oauth_encrypter = SystemOAuthEncrypter()
|
||||
return _oauth_encrypter
|
||||
|
||||
|
||||
# Convenience functions for backward compatibility
|
||||
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Encrypt OAuth parameters using the global encrypter.
|
||||
|
||||
Args:
|
||||
oauth_params: OAuth parameters dictionary
|
||||
|
||||
Returns:
|
||||
Base64-encoded encrypted string
|
||||
"""
|
||||
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
|
||||
|
||||
|
||||
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
|
||||
"""
|
||||
Decrypt OAuth parameters using the global encrypter.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64-encoded encrypted string
|
||||
|
||||
Returns:
|
||||
Decrypted OAuth parameters dictionary
|
||||
"""
|
||||
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
|
||||
17
dify/api/core/tools/utils/text_processing_utils.py
Normal file
17
dify/api/core/tools/utils/text_processing_utils.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import re
|
||||
|
||||
|
||||
def remove_leading_symbols(text: str) -> str:
|
||||
"""
|
||||
Remove leading punctuation or symbols from the given text.
|
||||
|
||||
Args:
|
||||
text (str): The input text to process.
|
||||
|
||||
Returns:
|
||||
str: The text with leading punctuation or symbols removed.
|
||||
"""
|
||||
# Match Unicode ranges for punctuation and symbols
|
||||
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
|
||||
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
|
||||
return re.sub(pattern, "", text)
|
||||
11
dify/api/core/tools/utils/uuid_utils.py
Normal file
11
dify/api/core/tools/utils/uuid_utils.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import uuid
|
||||
|
||||
|
||||
def is_valid_uuid(uuid_str: str | None) -> bool:
|
||||
if uuid_str is None or len(uuid_str) == 0:
|
||||
return False
|
||||
try:
|
||||
uuid.UUID(uuid_str)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
128
dify/api/core/tools/utils/web_reader_tool.py
Normal file
128
dify/api/core/tools/utils/web_reader_tool.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import mimetypes
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import chardet
|
||||
import cloudscraper
|
||||
from readabilipy import simple_json_from_html_string
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.rag.extractor import extract_processor
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
|
||||
FULL_TEMPLATE = """
|
||||
TITLE: {title}
|
||||
AUTHOR: {author}
|
||||
TEXT:
|
||||
|
||||
{text}
|
||||
"""
|
||||
|
||||
|
||||
def page_result(text: str, cursor: int, max_length: int) -> str:
|
||||
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
|
||||
return text[cursor : cursor + max_length]
|
||||
|
||||
|
||||
def get_url(url: str, user_agent: str | None = None) -> str:
|
||||
"""Fetch URL and return the contents as a string."""
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
|
||||
" Chrome/91.0.4472.124 Safari/537.36"
|
||||
}
|
||||
if user_agent:
|
||||
headers["User-Agent"] = user_agent
|
||||
|
||||
main_content_type = None
|
||||
supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
|
||||
response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10))
|
||||
|
||||
if response.status_code == 200:
|
||||
# check content-type
|
||||
content_type = response.headers.get("Content-Type")
|
||||
if content_type:
|
||||
main_content_type = response.headers.get("Content-Type").split(";")[0].strip()
|
||||
else:
|
||||
content_disposition = response.headers.get("Content-Disposition", "")
|
||||
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
|
||||
if filename_match:
|
||||
filename = unquote(filename_match.group(1))
|
||||
extension = re.search(r"\.(\w+)$", filename)
|
||||
if extension:
|
||||
main_content_type = mimetypes.guess_type(filename)[0]
|
||||
|
||||
if main_content_type not in supported_content_types:
|
||||
return f"Unsupported content-type [{main_content_type}] of URL."
|
||||
|
||||
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
|
||||
return cast(str, ExtractProcessor.load_from_url(url, return_text=True))
|
||||
|
||||
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
|
||||
elif response.status_code == 403:
|
||||
scraper = cloudscraper.create_scraper()
|
||||
scraper.perform_request = ssrf_proxy.make_request
|
||||
response = scraper.get(url, headers=headers, timeout=(120, 300))
|
||||
|
||||
if response.status_code != 200:
|
||||
return f"URL returned status code {response.status_code}."
|
||||
|
||||
# Detect encoding using chardet
|
||||
detected_encoding = chardet.detect(response.content)
|
||||
encoding = detected_encoding["encoding"]
|
||||
if encoding:
|
||||
try:
|
||||
content = response.content.decode(encoding)
|
||||
except (UnicodeDecodeError, TypeError):
|
||||
content = response.text
|
||||
else:
|
||||
content = response.text
|
||||
|
||||
article = extract_using_readabilipy(content)
|
||||
|
||||
if not article.text:
|
||||
return ""
|
||||
|
||||
res = FULL_TEMPLATE.format(
|
||||
title=article.title,
|
||||
author=article.author,
|
||||
text=article.text,
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@dataclass
|
||||
class Article:
|
||||
title: str
|
||||
author: str
|
||||
text: Sequence[dict]
|
||||
|
||||
|
||||
def extract_using_readabilipy(html: str):
|
||||
json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True)
|
||||
article = Article(
|
||||
title=json_article.get("title") or "",
|
||||
author=json_article.get("byline") or "",
|
||||
text=json_article.get("plain_text") or [],
|
||||
)
|
||||
|
||||
return article
|
||||
|
||||
|
||||
def get_image_upload_file_ids(content):
|
||||
pattern = r"!\[image\]\((https?://.*?(file-preview|image-preview))\)"
|
||||
matches = re.findall(pattern, content)
|
||||
image_upload_file_ids = []
|
||||
for match in matches:
|
||||
if match[1] == "file-preview":
|
||||
content_pattern = r"files/([^/]+)/file-preview"
|
||||
else:
|
||||
content_pattern = r"files/([^/]+)/image-preview"
|
||||
content_match = re.search(content_pattern, match[0])
|
||||
if content_match:
|
||||
image_upload_file_id = content_match.group(1)
|
||||
image_upload_file_ids.append(image_upload_file_id)
|
||||
return image_upload_file_ids
|
||||
43
dify/api/core/tools/utils/workflow_configuration_sync.py
Normal file
43
dify/api/core/tools/utils/workflow_configuration_sync.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
@classmethod
|
||||
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
|
||||
for configuration in configurations:
|
||||
WorkflowToolParameterConfiguration.model_validate(configuration)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
|
||||
"""
|
||||
get workflow graph variables
|
||||
"""
|
||||
nodes = graph.get("nodes", [])
|
||||
start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)
|
||||
|
||||
if not start_node:
|
||||
return []
|
||||
|
||||
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
|
||||
|
||||
@classmethod
|
||||
def check_is_synced(
|
||||
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
|
||||
):
|
||||
"""
|
||||
check is synced
|
||||
|
||||
raise ValueError if not synced
|
||||
"""
|
||||
variable_names = [variable.variable for variable in variables]
|
||||
|
||||
if len(tool_configurations) != len(variables):
|
||||
raise ValueError("parameter configuration mismatch, please republish the tool to update")
|
||||
|
||||
for parameter in tool_configurations:
|
||||
if parameter.name not in variable_names:
|
||||
raise ValueError("parameter configuration mismatch, please republish the tool to update")
|
||||
33
dify/api/core/tools/utils/yaml_utils.py
Normal file
33
dify/api/core/tools/utils/yaml_utils.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from yaml import YAMLError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _load_yaml_file(*, file_path: str):
|
||||
if not file_path or not Path(file_path).exists():
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
with open(file_path, encoding="utf-8") as yaml_file:
|
||||
try:
|
||||
yaml_content = yaml.safe_load(yaml_file)
|
||||
return yaml_content
|
||||
except Exception as e:
|
||||
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def load_yaml_file_cached(file_path: str) -> Any:
|
||||
"""
|
||||
Cached version of load_yaml_file for static configuration files.
|
||||
Only use for files that don't change during runtime (e.g., position files)
|
||||
|
||||
:param file_path: the path of the YAML file
|
||||
:return: an object of the YAML content
|
||||
"""
|
||||
return _load_yaml_file(file_path=file_path)
|
||||
240
dify/api/core/tools/workflow_as_tool/provider.py
Normal file
240
dify/api/core/tools/workflow_as_tool/provider.py
Normal file
@@ -0,0 +1,240 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolEntity,
|
||||
ToolIdentity,
|
||||
ToolParameter,
|
||||
ToolProviderEntity,
|
||||
ToolProviderIdentity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, AppMode
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.workflow import Workflow
|
||||
|
||||
VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
|
||||
VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING,
|
||||
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
|
||||
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
|
||||
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
|
||||
VariableEntityType.CHECKBOX: ToolParameter.ToolParameterType.BOOLEAN,
|
||||
VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE,
|
||||
VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES,
|
||||
}
|
||||
|
||||
|
||||
class WorkflowToolProviderController(ToolProviderController):
|
||||
provider_id: str
|
||||
tools: list[WorkflowTool] = Field(default_factory=list)
|
||||
|
||||
def __init__(self, entity: ToolProviderEntity, provider_id: str):
|
||||
super().__init__(entity=entity)
|
||||
self.provider_id = provider_id
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
provider = session.get(WorkflowToolProvider, db_provider.id) if db_provider.id else None
|
||||
if not provider:
|
||||
raise ValueError("workflow provider not found")
|
||||
app = session.get(App, provider.app_id)
|
||||
if not app:
|
||||
raise ValueError("app not found")
|
||||
|
||||
user = session.get(Account, provider.user_id) if provider.user_id else None
|
||||
|
||||
controller = WorkflowToolProviderController(
|
||||
entity=ToolProviderEntity(
|
||||
identity=ToolProviderIdentity(
|
||||
author=user.name if user else "",
|
||||
name=provider.label,
|
||||
label=I18nObject(en_US=provider.label, zh_Hans=provider.label),
|
||||
description=I18nObject(en_US=provider.description, zh_Hans=provider.description),
|
||||
icon=provider.icon,
|
||||
),
|
||||
credentials_schema=[],
|
||||
plugin_id=None,
|
||||
),
|
||||
provider_id=provider.id or "",
|
||||
)
|
||||
|
||||
controller.tools = [
|
||||
controller._get_db_provider_tool(provider, app, session=session, user=user),
|
||||
]
|
||||
|
||||
return controller
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.WORKFLOW
|
||||
|
||||
def _get_db_provider_tool(
|
||||
self,
|
||||
db_provider: WorkflowToolProvider,
|
||||
app: App,
|
||||
*,
|
||||
session: Session,
|
||||
user: Account | None = None,
|
||||
) -> WorkflowTool:
|
||||
"""
|
||||
get db provider tool
|
||||
:param db_provider: the db provider
|
||||
:param app: the app
|
||||
:return: the tool
|
||||
"""
|
||||
workflow: Workflow | None = (
|
||||
session.query(Workflow)
|
||||
.where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
raise ValueError("workflow not found")
|
||||
|
||||
# fetch start node
|
||||
graph: Mapping = workflow.graph_dict
|
||||
features_dict: Mapping = workflow.features_dict
|
||||
features = WorkflowAppConfigManager.convert_features(config_dict=features_dict, app_mode=AppMode.WORKFLOW)
|
||||
|
||||
parameters = db_provider.parameter_configurations
|
||||
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
|
||||
|
||||
def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
|
||||
return next(filter(lambda x: x.variable == variable_name, variables), None)
|
||||
|
||||
workflow_tool_parameters = []
|
||||
for parameter in parameters:
|
||||
variable = fetch_workflow_variable(parameter.name)
|
||||
if variable:
|
||||
parameter_type = None
|
||||
options = []
|
||||
if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
|
||||
raise ValueError(f"unsupported variable type {variable.type}")
|
||||
parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]
|
||||
|
||||
if variable.type == VariableEntityType.SELECT and variable.options:
|
||||
options = [
|
||||
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
||||
for option in variable.options
|
||||
]
|
||||
|
||||
workflow_tool_parameters.append(
|
||||
ToolParameter(
|
||||
name=parameter.name,
|
||||
label=I18nObject(en_US=variable.label, zh_Hans=variable.label),
|
||||
human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description),
|
||||
type=parameter_type,
|
||||
form=parameter.form,
|
||||
llm_description=parameter.description,
|
||||
required=variable.required,
|
||||
default=variable.default,
|
||||
options=options,
|
||||
placeholder=I18nObject(en_US="", zh_Hans=""),
|
||||
)
|
||||
)
|
||||
elif features.file_upload:
|
||||
workflow_tool_parameters.append(
|
||||
ToolParameter(
|
||||
name=parameter.name,
|
||||
label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name),
|
||||
human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description),
|
||||
type=ToolParameter.ToolParameterType.SYSTEM_FILES,
|
||||
llm_description=parameter.description,
|
||||
required=False,
|
||||
form=parameter.form,
|
||||
placeholder=I18nObject(en_US="", zh_Hans=""),
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError("variable not found")
|
||||
|
||||
return WorkflowTool(
|
||||
workflow_as_tool_id=db_provider.id,
|
||||
entity=ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
author=user.name if user else "",
|
||||
name=db_provider.name,
|
||||
label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label),
|
||||
provider=self.provider_id,
|
||||
icon=db_provider.icon,
|
||||
),
|
||||
description=ToolDescription(
|
||||
human=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
|
||||
llm=db_provider.description,
|
||||
),
|
||||
parameters=workflow_tool_parameters,
|
||||
),
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
),
|
||||
workflow_app_id=app.id,
|
||||
workflow_entities={
|
||||
"app": app,
|
||||
"workflow": workflow,
|
||||
},
|
||||
version=db_provider.version,
|
||||
workflow_call_depth=0,
|
||||
label=db_provider.label,
|
||||
)
|
||||
|
||||
def get_tools(self, tenant_id: str) -> list[WorkflowTool]:
|
||||
"""
|
||||
fetch tools from database
|
||||
|
||||
:param tenant_id: the tenant id
|
||||
:return: the tools
|
||||
"""
|
||||
if self.tools is not None:
|
||||
return self.tools
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
db_provider: WorkflowToolProvider | None = (
|
||||
session.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.app_id == self.provider_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not db_provider:
|
||||
return []
|
||||
|
||||
app = session.get(App, db_provider.app_id)
|
||||
if not app:
|
||||
raise ValueError("app not found")
|
||||
|
||||
user = session.get(Account, db_provider.user_id) if db_provider.user_id else None
|
||||
self.tools = [self._get_db_provider_tool(db_provider, app, session=session, user=user)]
|
||||
|
||||
return self.tools
|
||||
|
||||
def get_tool(self, tool_name: str) -> WorkflowTool | None: # type: ignore
|
||||
"""
|
||||
get tool by name
|
||||
|
||||
:param tool_name: the name of the tool
|
||||
:return: the tool
|
||||
"""
|
||||
if self.tools is None:
|
||||
return None
|
||||
|
||||
for tool in self.tools:
|
||||
if tool.entity.identity.name == tool_name:
|
||||
return tool
|
||||
|
||||
return None
|
||||
346
dify/api/core/tools/workflow_as_tool/tool.py
Normal file
346
dify/api/core/tools/workflow_as_tool/tool.py
Normal file
@@ -0,0 +1,346 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import has_request_context
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolEntity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping
|
||||
from libs.login import current_user
|
||||
from models import Account, Tenant
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowTool(Tool):
|
||||
"""
|
||||
Workflow tool.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workflow_app_id: str,
|
||||
workflow_as_tool_id: str,
|
||||
version: str,
|
||||
workflow_entities: dict[str, Any],
|
||||
workflow_call_depth: int,
|
||||
entity: ToolEntity,
|
||||
runtime: ToolRuntime,
|
||||
label: str = "Workflow",
|
||||
):
|
||||
self.workflow_app_id = workflow_app_id
|
||||
self.workflow_as_tool_id = workflow_as_tool_id
|
||||
self.version = version
|
||||
self.workflow_entities = workflow_entities
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
self.label = label
|
||||
self._latest_usage = LLMUsage.empty_usage()
|
||||
|
||||
super().__init__(entity=entity, runtime=runtime)
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
get the tool provider type
|
||||
|
||||
:return: the tool provider type
|
||||
"""
|
||||
return ToolProviderType.WORKFLOW
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
invoke the tool
|
||||
"""
|
||||
app = self._get_app(app_id=self.workflow_app_id)
|
||||
workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version)
|
||||
|
||||
# transform the tool parameters
|
||||
tool_parameters, files = self._transform_args(tool_parameters=tool_parameters)
|
||||
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
|
||||
generator = WorkflowAppGenerator()
|
||||
assert self.runtime is not None
|
||||
assert self.runtime.invoke_from is not None
|
||||
|
||||
user = self._resolve_user(user_id=user_id)
|
||||
if user is None:
|
||||
raise ToolInvokeError("User not found")
|
||||
|
||||
self._latest_usage = LLMUsage.empty_usage()
|
||||
|
||||
result = generator.generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args={"inputs": tool_parameters, "files": files},
|
||||
invoke_from=self.runtime.invoke_from,
|
||||
streaming=False,
|
||||
call_depth=self.workflow_call_depth + 1,
|
||||
)
|
||||
assert isinstance(result, dict)
|
||||
data = result.get("data", {})
|
||||
|
||||
if err := data.get("error"):
|
||||
raise ToolInvokeError(err)
|
||||
|
||||
outputs = data.get("outputs")
|
||||
if outputs is None:
|
||||
outputs = {}
|
||||
else:
|
||||
outputs, files = self._extract_files(outputs) # type: ignore
|
||||
for file in files:
|
||||
yield self.create_file_message(file) # type: ignore
|
||||
|
||||
self._latest_usage = self._derive_usage_from_result(data)
|
||||
|
||||
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
|
||||
yield self.create_json_message(outputs, suppress_output=True)
|
||||
|
||||
@property
|
||||
def latest_usage(self) -> LLMUsage:
|
||||
return self._latest_usage
|
||||
|
||||
@classmethod
|
||||
def _derive_usage_from_result(cls, data: Mapping[str, Any]) -> LLMUsage:
|
||||
usage_dict = cls._extract_usage_dict(data)
|
||||
if usage_dict is not None:
|
||||
return LLMUsage.from_metadata(cast(LLMUsageMetadata, dict(usage_dict)))
|
||||
|
||||
total_tokens = data.get("total_tokens")
|
||||
total_price = data.get("total_price")
|
||||
if total_tokens is None and total_price is None:
|
||||
return LLMUsage.empty_usage()
|
||||
|
||||
usage_metadata: dict[str, Any] = {}
|
||||
if total_tokens is not None:
|
||||
try:
|
||||
usage_metadata["total_tokens"] = int(str(total_tokens))
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
if total_price is not None:
|
||||
usage_metadata["total_price"] = str(total_price)
|
||||
currency = data.get("currency")
|
||||
if currency is not None:
|
||||
usage_metadata["currency"] = currency
|
||||
|
||||
if not usage_metadata:
|
||||
return LLMUsage.empty_usage()
|
||||
|
||||
return LLMUsage.from_metadata(cast(LLMUsageMetadata, usage_metadata))
|
||||
|
||||
@classmethod
|
||||
def _extract_usage_dict(cls, payload: Mapping[str, Any]) -> Mapping[str, Any] | None:
|
||||
usage_candidate = payload.get("usage")
|
||||
if isinstance(usage_candidate, Mapping):
|
||||
return usage_candidate
|
||||
|
||||
metadata_candidate = payload.get("metadata")
|
||||
if isinstance(metadata_candidate, Mapping):
|
||||
usage_candidate = metadata_candidate.get("usage")
|
||||
if isinstance(usage_candidate, Mapping):
|
||||
return usage_candidate
|
||||
|
||||
for value in payload.values():
|
||||
if isinstance(value, Mapping):
|
||||
found = cls._extract_usage_dict(value)
|
||||
if found is not None:
|
||||
return found
|
||||
elif isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
|
||||
for item in value:
|
||||
if isinstance(item, Mapping):
|
||||
found = cls._extract_usage_dict(item)
|
||||
if found is not None:
|
||||
return found
|
||||
return None
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
|
||||
"""
|
||||
fork a new tool with metadata
|
||||
|
||||
:return: the new tool
|
||||
"""
|
||||
return self.__class__(
|
||||
entity=self.entity.model_copy(),
|
||||
runtime=runtime,
|
||||
workflow_app_id=self.workflow_app_id,
|
||||
workflow_as_tool_id=self.workflow_as_tool_id,
|
||||
workflow_entities=self.workflow_entities,
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
version=self.version,
|
||||
label=self.label,
|
||||
)
|
||||
|
||||
def _resolve_user(self, user_id: str) -> Account | EndUser | None:
|
||||
"""
|
||||
Resolve user object in both HTTP and worker contexts.
|
||||
|
||||
In HTTP context: dereference the current_user LocalProxy (can return Account or EndUser).
|
||||
In worker context: load Account from database by user_id (only returns Account, never EndUser).
|
||||
|
||||
Returns:
|
||||
Account | EndUser | None: The resolved user object, or None if resolution fails.
|
||||
"""
|
||||
if has_request_context():
|
||||
return self._resolve_user_from_request()
|
||||
else:
|
||||
return self._resolve_user_from_database(user_id=user_id)
|
||||
|
||||
def _resolve_user_from_request(self) -> Account | EndUser | None:
|
||||
"""
|
||||
Resolve user from Flask request context.
|
||||
"""
|
||||
try:
|
||||
# Note: `current_user` is a LocalProxy. Never compare it with None directly.
|
||||
return getattr(current_user, "_get_current_object", lambda: current_user)()
|
||||
except Exception as e:
|
||||
logger.warning("Failed to resolve user from request context: %s", e)
|
||||
return None
|
||||
|
||||
def _resolve_user_from_database(self, user_id: str) -> Account | None:
|
||||
"""
|
||||
Resolve user from database (worker/Celery context).
|
||||
"""
|
||||
|
||||
user_stmt = select(Account).where(Account.id == user_id)
|
||||
user = db.session.scalar(user_stmt)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
|
||||
tenant = db.session.scalar(tenant_stmt)
|
||||
if not tenant:
|
||||
return None
|
||||
|
||||
user.current_tenant = tenant
|
||||
|
||||
return user
|
||||
|
||||
def _get_workflow(self, app_id: str, version: str) -> Workflow:
|
||||
"""
|
||||
get the workflow by app id and version
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
if not version:
|
||||
stmt = (
|
||||
select(Workflow)
|
||||
.where(Workflow.app_id == app_id, Workflow.version != Workflow.VERSION_DRAFT)
|
||||
.order_by(Workflow.created_at.desc())
|
||||
)
|
||||
workflow = session.scalars(stmt).first()
|
||||
else:
|
||||
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
|
||||
workflow = session.scalar(stmt)
|
||||
|
||||
if not workflow:
|
||||
raise ValueError("workflow not found or not published")
|
||||
|
||||
return workflow
|
||||
|
||||
def _get_app(self, app_id: str) -> App:
|
||||
"""
|
||||
get the app by app id
|
||||
"""
|
||||
stmt = select(App).where(App.id == app_id)
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
app = session.scalar(stmt)
|
||||
if not app:
|
||||
raise ValueError("app not found")
|
||||
|
||||
return app
|
||||
|
||||
def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
|
||||
"""
|
||||
transform the tool parameters
|
||||
|
||||
:param tool_parameters: the tool parameters
|
||||
:return: tool_parameters, files
|
||||
"""
|
||||
parameter_rules = self.get_merged_runtime_parameters()
|
||||
parameters_result = {}
|
||||
files = []
|
||||
for parameter in parameter_rules:
|
||||
if parameter.type == ToolParameter.ToolParameterType.SYSTEM_FILES:
|
||||
file = tool_parameters.get(parameter.name)
|
||||
if file:
|
||||
try:
|
||||
file_var_list = [File.model_validate(f) for f in file]
|
||||
for file in file_var_list:
|
||||
file_dict: dict[str, str | None] = {
|
||||
"transfer_method": file.transfer_method.value,
|
||||
"type": file.type.value,
|
||||
}
|
||||
if file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
file_dict["tool_file_id"] = file.related_id
|
||||
elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
file_dict["upload_file_id"] = file.related_id
|
||||
elif file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
file_dict["url"] = file.generate_url()
|
||||
|
||||
files.append(file_dict)
|
||||
except Exception:
|
||||
logger.exception("Failed to transform file %s", file)
|
||||
else:
|
||||
parameters_result[parameter.name] = tool_parameters.get(parameter.name)
|
||||
|
||||
return parameters_result, files
|
||||
|
||||
def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]:
|
||||
"""
|
||||
extract files from the result
|
||||
|
||||
:return: the result, files
|
||||
"""
|
||||
files: list[File] = []
|
||||
result = {}
|
||||
for key, value in outputs.items():
|
||||
if isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
item = self._update_file_mapping(item)
|
||||
file = build_from_mapping(
|
||||
mapping=item,
|
||||
tenant_id=str(self.runtime.tenant_id),
|
||||
)
|
||||
files.append(file)
|
||||
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
value = self._update_file_mapping(value)
|
||||
file = build_from_mapping(
|
||||
mapping=value,
|
||||
tenant_id=str(self.runtime.tenant_id),
|
||||
)
|
||||
files.append(file)
|
||||
|
||||
result[key] = value
|
||||
|
||||
return result, files
|
||||
|
||||
def _update_file_mapping(self, file_dict: dict):
|
||||
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
|
||||
if transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
file_dict["tool_file_id"] = file_dict.get("related_id")
|
||||
elif transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
file_dict["upload_file_id"] = file_dict.get("related_id")
|
||||
return file_dict
|
||||
Reference in New Issue
Block a user