dify
This commit is contained in:
9
dify/api/core/plugin/entities/base.py
Normal file
9
dify/api/core/plugin/entities/base.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BasePluginEntity(BaseModel):
|
||||
id: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
30
dify/api/core/plugin/entities/bundle.py
Normal file
30
dify/api/core/plugin/entities/bundle.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.plugin.entities.plugin import PluginDeclaration, PluginInstallationSource
|
||||
|
||||
|
||||
class PluginBundleDependency(BaseModel):
|
||||
class Type(StrEnum):
|
||||
Github = PluginInstallationSource.Github.value
|
||||
Marketplace = PluginInstallationSource.Marketplace.value
|
||||
Package = PluginInstallationSource.Package.value
|
||||
|
||||
class Github(BaseModel):
|
||||
repo_address: str
|
||||
repo: str
|
||||
release: str
|
||||
packages: str
|
||||
|
||||
class Marketplace(BaseModel):
|
||||
organization: str
|
||||
plugin: str
|
||||
version: str
|
||||
|
||||
class Package(BaseModel):
|
||||
unique_identifier: str
|
||||
manifest: PluginDeclaration
|
||||
|
||||
type: Type
|
||||
value: Github | Marketplace | Package
|
||||
53
dify/api/core/plugin/entities/endpoint.py
Normal file
53
dify/api/core/plugin/entities/endpoint.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.base import BasePluginEntity
|
||||
|
||||
|
||||
class EndpointDeclaration(BaseModel):
|
||||
"""
|
||||
declaration of an endpoint
|
||||
"""
|
||||
|
||||
path: str
|
||||
method: str
|
||||
hidden: bool = Field(default=False)
|
||||
|
||||
|
||||
class EndpointProviderDeclaration(BaseModel):
|
||||
"""
|
||||
declaration of an endpoint group
|
||||
"""
|
||||
|
||||
settings: list[ProviderConfig] = Field(default_factory=list)
|
||||
endpoints: list[EndpointDeclaration] | None = Field(default_factory=list[EndpointDeclaration])
|
||||
|
||||
|
||||
class EndpointEntity(BasePluginEntity):
|
||||
"""
|
||||
entity of an endpoint
|
||||
"""
|
||||
|
||||
settings: dict
|
||||
tenant_id: str
|
||||
plugin_id: str
|
||||
expired_at: datetime
|
||||
declaration: EndpointProviderDeclaration = Field(default_factory=EndpointProviderDeclaration)
|
||||
|
||||
|
||||
class EndpointEntityWithInstance(EndpointEntity):
|
||||
name: str
|
||||
enabled: bool
|
||||
url: str
|
||||
hook_id: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def render_url_template(cls, values):
|
||||
if "url" not in values:
|
||||
url_template = dify_config.ENDPOINT_URL_TEMPLATE
|
||||
values["url"] = url_template.replace("{hook_id}", values["hook_id"])
|
||||
return values
|
||||
50
dify/api/core/plugin/entities/marketplace.py
Normal file
50
dify/api/core/plugin/entities/marketplace.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.plugin.entities.endpoint import EndpointProviderDeclaration
|
||||
from core.plugin.entities.plugin import PluginResourceRequirements
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderEntity
|
||||
|
||||
|
||||
class MarketplacePluginDeclaration(BaseModel):
|
||||
name: str = Field(..., description="Unique identifier for the plugin within the marketplace")
|
||||
org: str = Field(..., description="Organization or developer responsible for creating and maintaining the plugin")
|
||||
plugin_id: str = Field(..., description="Globally unique identifier for the plugin across all marketplaces")
|
||||
icon: str = Field(..., description="URL or path to the plugin's visual representation")
|
||||
label: I18nObject = Field(..., description="Localized display name for the plugin in different languages")
|
||||
brief: I18nObject = Field(..., description="Short, localized description of the plugin's functionality")
|
||||
resource: PluginResourceRequirements = Field(
|
||||
..., description="Specification of computational resources needed to run the plugin"
|
||||
)
|
||||
endpoint: EndpointProviderDeclaration | None = Field(
|
||||
None, description="Configuration for the plugin's API endpoint, if applicable"
|
||||
)
|
||||
model: ProviderEntity | None = Field(None, description="Details of the AI model used by the plugin, if any")
|
||||
tool: ToolProviderEntity | None = Field(
|
||||
None, description="Information about the tool functionality provided by the plugin, if any"
|
||||
)
|
||||
latest_version: str = Field(
|
||||
..., description="Most recent version number of the plugin available in the marketplace"
|
||||
)
|
||||
latest_package_identifier: str = Field(
|
||||
..., description="Unique identifier for the latest package release of the plugin"
|
||||
)
|
||||
status: str = Field(..., description="Indicate the status of marketplace plugin, enum from `active` `deleted`")
|
||||
deprecated_reason: str = Field(
|
||||
..., description="Not empty when status='deleted', indicates the reason why this plugin is deleted(deprecated)"
|
||||
)
|
||||
alternative_plugin_id: str = Field(
|
||||
..., description="Optional, indicates the alternative plugin for user to switch to"
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def transform_declaration(cls, data: dict):
|
||||
if "endpoint" in data and not data["endpoint"]:
|
||||
del data["endpoint"]
|
||||
if "model" in data and not data["model"]:
|
||||
del data["model"]
|
||||
if "tool" in data and not data["tool"]:
|
||||
del data["tool"]
|
||||
return data
|
||||
21
dify/api/core/plugin/entities/oauth.py
Normal file
21
dify/api/core/plugin/entities/oauth.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
|
||||
|
||||
class OAuthSchema(BaseModel):
|
||||
"""
|
||||
OAuth schema
|
||||
"""
|
||||
|
||||
client_schema: Sequence[ProviderConfig] = Field(
|
||||
default_factory=list,
|
||||
description="client schema like client_id, client_secret, etc.",
|
||||
)
|
||||
|
||||
credentials_schema: Sequence[ProviderConfig] = Field(
|
||||
default_factory=list,
|
||||
description="credentials schema like access_token, refresh_token, etc.",
|
||||
)
|
||||
214
dify/api/core/plugin/entities/parameters.py
Normal file
214
dify/api/core/plugin/entities/parameters.py
Normal file
@@ -0,0 +1,214 @@
|
||||
import json
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.entities.parameter_entities import CommonParameterType
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class PluginParameterOption(BaseModel):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
label: I18nObject = Field(..., description="The label of the option")
|
||||
icon: str | None = Field(default=None, description="The icon of the option, can be a url or a base64 encoded image")
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def transform_id_to_str(cls, value) -> str:
|
||||
if not isinstance(value, str):
|
||||
return str(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
class PluginParameterType(StrEnum):
|
||||
"""
|
||||
all available parameter types
|
||||
"""
|
||||
|
||||
STRING = CommonParameterType.STRING
|
||||
NUMBER = CommonParameterType.NUMBER
|
||||
BOOLEAN = CommonParameterType.BOOLEAN
|
||||
SELECT = CommonParameterType.SELECT
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT
|
||||
FILE = CommonParameterType.FILE
|
||||
FILES = CommonParameterType.FILES
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
|
||||
ANY = CommonParameterType.ANY
|
||||
DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT
|
||||
CHECKBOX = CommonParameterType.CHECKBOX
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES
|
||||
|
||||
# MCP object and array type parameters
|
||||
ARRAY = CommonParameterType.ARRAY
|
||||
OBJECT = CommonParameterType.OBJECT
|
||||
|
||||
|
||||
class MCPServerParameterType(StrEnum):
|
||||
"""
|
||||
MCP server got complex parameter types
|
||||
"""
|
||||
|
||||
ARRAY = auto()
|
||||
OBJECT = auto()
|
||||
|
||||
|
||||
class PluginParameterAutoGenerate(BaseModel):
|
||||
class Type(StrEnum):
|
||||
PROMPT_INSTRUCTION = auto()
|
||||
|
||||
type: Type
|
||||
|
||||
|
||||
class PluginParameterTemplate(BaseModel):
|
||||
enabled: bool = Field(default=False, description="Whether the parameter is jinja enabled")
|
||||
|
||||
|
||||
class PluginParameter(BaseModel):
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
label: I18nObject = Field(..., description="The label presented to the user")
|
||||
placeholder: I18nObject | None = Field(default=None, description="The placeholder presented to the user")
|
||||
scope: str | None = None
|
||||
auto_generate: PluginParameterAutoGenerate | None = None
|
||||
template: PluginParameterTemplate | None = None
|
||||
required: bool = False
|
||||
default: Union[float, int, str, bool] | None = None
|
||||
min: Union[float, int] | None = None
|
||||
max: Union[float, int] | None = None
|
||||
precision: int | None = None
|
||||
options: list[PluginParameterOption] = Field(default_factory=list)
|
||||
|
||||
@field_validator("options", mode="before")
|
||||
@classmethod
|
||||
def transform_options(cls, v):
|
||||
if not isinstance(v, list):
|
||||
return []
|
||||
return v
|
||||
|
||||
|
||||
def as_normal_type(typ: StrEnum):
|
||||
if typ.value in {
|
||||
PluginParameterType.SECRET_INPUT,
|
||||
PluginParameterType.SELECT,
|
||||
PluginParameterType.CHECKBOX,
|
||||
}:
|
||||
return "string"
|
||||
return typ.value
|
||||
|
||||
|
||||
def cast_parameter_value(typ: StrEnum, value: Any, /):
|
||||
try:
|
||||
match typ.value:
|
||||
case (
|
||||
PluginParameterType.STRING
|
||||
| PluginParameterType.SECRET_INPUT
|
||||
| PluginParameterType.SELECT
|
||||
| PluginParameterType.CHECKBOX
|
||||
| PluginParameterType.DYNAMIC_SELECT
|
||||
):
|
||||
if value is None:
|
||||
return ""
|
||||
else:
|
||||
return value if isinstance(value, str) else str(value)
|
||||
|
||||
case PluginParameterType.BOOLEAN:
|
||||
if value is None:
|
||||
return False
|
||||
elif isinstance(value, str):
|
||||
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
|
||||
# and also '0' for False and '1' for True
|
||||
match value.lower():
|
||||
case "true" | "yes" | "y" | "1":
|
||||
return True
|
||||
case "false" | "no" | "n" | "0":
|
||||
return False
|
||||
case _:
|
||||
return bool(value)
|
||||
else:
|
||||
return value if isinstance(value, bool) else bool(value)
|
||||
|
||||
case PluginParameterType.NUMBER:
|
||||
if isinstance(value, int | float):
|
||||
return value
|
||||
elif isinstance(value, str) and value:
|
||||
if "." in value:
|
||||
return float(value)
|
||||
else:
|
||||
return int(value)
|
||||
case PluginParameterType.SYSTEM_FILES | PluginParameterType.FILES:
|
||||
if not isinstance(value, list):
|
||||
return [value]
|
||||
return value
|
||||
case PluginParameterType.FILE:
|
||||
if isinstance(value, list):
|
||||
if len(value) != 1:
|
||||
raise ValueError("This parameter only accepts one file but got multiple files while invoking.")
|
||||
else:
|
||||
return value[0]
|
||||
return value
|
||||
case PluginParameterType.MODEL_SELECTOR | PluginParameterType.APP_SELECTOR:
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError("The selector must be a dictionary.")
|
||||
return value
|
||||
case PluginParameterType.TOOLS_SELECTOR:
|
||||
if value and not isinstance(value, list):
|
||||
raise ValueError("The tools selector must be a list.")
|
||||
return value
|
||||
case PluginParameterType.ANY:
|
||||
if value and not isinstance(value, str | dict | list | int | float):
|
||||
raise ValueError("The var selector must be a string, dictionary, list or number.")
|
||||
return value
|
||||
case PluginParameterType.ARRAY:
|
||||
if not isinstance(value, list):
|
||||
# Try to parse JSON string for arrays
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
parsed_value = json.loads(value)
|
||||
if isinstance(parsed_value, list):
|
||||
return parsed_value
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
return [value]
|
||||
return value
|
||||
case PluginParameterType.OBJECT:
|
||||
if not isinstance(value, dict):
|
||||
# Try to parse JSON string for objects
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
parsed_value = json.loads(value)
|
||||
if isinstance(parsed_value, dict):
|
||||
return parsed_value
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
return {}
|
||||
return value
|
||||
case _:
|
||||
return str(value)
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception:
|
||||
raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.")
|
||||
|
||||
|
||||
def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any):
|
||||
"""
|
||||
init frontend parameter by rule
|
||||
"""
|
||||
parameter_value = value
|
||||
if not parameter_value and parameter_value != 0:
|
||||
# get default value
|
||||
parameter_value = rule.default
|
||||
if not parameter_value and rule.required:
|
||||
raise ValueError(f"tool parameter {rule.name} not found in tool config")
|
||||
|
||||
if type == PluginParameterType.SELECT:
|
||||
# check if tool_parameter_config in options
|
||||
options = [x.value for x in rule.options]
|
||||
if parameter_value is not None and parameter_value not in options:
|
||||
raise ValueError(f"tool parameter {rule.name} value {parameter_value} not in options {options}")
|
||||
|
||||
return cast_parameter_value(type, parameter_value)
|
||||
204
dify/api/core/plugin/entities/plugin.py
Normal file
204
dify/api/core/plugin/entities/plugin.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import datetime
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from packaging.version import InvalidVersion, Version
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from core.agent.plugin_entities import AgentStrategyProviderEntity
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntity
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.plugin.entities.base import BasePluginEntity
|
||||
from core.plugin.entities.endpoint import EndpointProviderDeclaration
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderEntity
|
||||
from core.trigger.entities.entities import TriggerProviderEntity
|
||||
|
||||
|
||||
class PluginInstallationSource(StrEnum):
|
||||
Github = auto()
|
||||
Marketplace = auto()
|
||||
Package = auto()
|
||||
Remote = auto()
|
||||
|
||||
|
||||
class PluginResourceRequirements(BaseModel):
|
||||
memory: int
|
||||
|
||||
class Permission(BaseModel):
|
||||
class Tool(BaseModel):
|
||||
enabled: bool | None = Field(default=False)
|
||||
|
||||
class Model(BaseModel):
|
||||
enabled: bool | None = Field(default=False)
|
||||
llm: bool | None = Field(default=False)
|
||||
text_embedding: bool | None = Field(default=False)
|
||||
rerank: bool | None = Field(default=False)
|
||||
tts: bool | None = Field(default=False)
|
||||
speech2text: bool | None = Field(default=False)
|
||||
moderation: bool | None = Field(default=False)
|
||||
|
||||
class Node(BaseModel):
|
||||
enabled: bool | None = Field(default=False)
|
||||
|
||||
class Endpoint(BaseModel):
|
||||
enabled: bool | None = Field(default=False)
|
||||
|
||||
class Storage(BaseModel):
|
||||
enabled: bool | None = Field(default=False)
|
||||
size: int = Field(ge=1024, le=1073741824, default=1048576)
|
||||
|
||||
tool: Tool | None = Field(default=None)
|
||||
model: Model | None = Field(default=None)
|
||||
node: Node | None = Field(default=None)
|
||||
endpoint: Endpoint | None = Field(default=None)
|
||||
storage: Storage | None = Field(default=None)
|
||||
|
||||
permission: Permission | None = Field(default=None)
|
||||
|
||||
|
||||
class PluginCategory(StrEnum):
|
||||
Tool = auto()
|
||||
Model = auto()
|
||||
Extension = auto()
|
||||
AgentStrategy = "agent-strategy"
|
||||
Datasource = "datasource"
|
||||
Trigger = "trigger"
|
||||
|
||||
|
||||
class PluginDeclaration(BaseModel):
|
||||
class Plugins(BaseModel):
|
||||
tools: list[str] | None = Field(default_factory=list[str])
|
||||
models: list[str] | None = Field(default_factory=list[str])
|
||||
endpoints: list[str] | None = Field(default_factory=list[str])
|
||||
datasources: list[str] | None = Field(default_factory=list[str])
|
||||
triggers: list[str] | None = Field(default_factory=list[str])
|
||||
|
||||
class Meta(BaseModel):
|
||||
minimum_dify_version: str | None = Field(default=None)
|
||||
version: str | None = Field(default=None)
|
||||
|
||||
@field_validator("minimum_dify_version")
|
||||
@classmethod
|
||||
def validate_minimum_dify_version(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return v
|
||||
try:
|
||||
Version(v)
|
||||
return v
|
||||
except InvalidVersion as e:
|
||||
raise ValueError(f"Invalid version format: {v}") from e
|
||||
|
||||
version: str = Field(...)
|
||||
author: str | None = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$")
|
||||
name: str = Field(..., pattern=r"^[a-z0-9_-]{1,128}$")
|
||||
description: I18nObject
|
||||
icon: str
|
||||
icon_dark: str | None = Field(default=None)
|
||||
label: I18nObject
|
||||
category: PluginCategory
|
||||
created_at: datetime.datetime
|
||||
resource: PluginResourceRequirements
|
||||
plugins: Plugins
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
repo: str | None = Field(default=None)
|
||||
verified: bool = Field(default=False)
|
||||
tool: ToolProviderEntity | None = None
|
||||
model: ProviderEntity | None = None
|
||||
endpoint: EndpointProviderDeclaration | None = None
|
||||
agent_strategy: AgentStrategyProviderEntity | None = None
|
||||
datasource: DatasourceProviderEntity | None = None
|
||||
trigger: TriggerProviderEntity | None = None
|
||||
meta: Meta
|
||||
|
||||
@field_validator("version")
|
||||
@classmethod
|
||||
def validate_version(cls, v: str) -> str:
|
||||
try:
|
||||
Version(v)
|
||||
return v
|
||||
except InvalidVersion as e:
|
||||
raise ValueError(f"Invalid version format: {v}") from e
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_category(cls, values: dict):
|
||||
# auto detect category
|
||||
if values.get("tool"):
|
||||
values["category"] = PluginCategory.Tool
|
||||
elif values.get("model"):
|
||||
values["category"] = PluginCategory.Model
|
||||
elif values.get("datasource"):
|
||||
values["category"] = PluginCategory.Datasource
|
||||
elif values.get("agent_strategy"):
|
||||
values["category"] = PluginCategory.AgentStrategy
|
||||
elif values.get("trigger"):
|
||||
values["category"] = PluginCategory.Trigger
|
||||
else:
|
||||
values["category"] = PluginCategory.Extension
|
||||
return values
|
||||
|
||||
|
||||
class PluginInstallation(BasePluginEntity):
|
||||
tenant_id: str
|
||||
endpoints_setups: int
|
||||
endpoints_active: int
|
||||
runtime_type: str
|
||||
source: PluginInstallationSource
|
||||
meta: Mapping[str, Any]
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
version: str
|
||||
checksum: str
|
||||
declaration: PluginDeclaration
|
||||
|
||||
|
||||
class PluginEntity(PluginInstallation):
|
||||
name: str
|
||||
installation_id: str
|
||||
version: str
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_plugin_id(self):
|
||||
if self.declaration.tool:
|
||||
self.declaration.tool.plugin_id = self.plugin_id
|
||||
return self
|
||||
|
||||
|
||||
class PluginDependency(BaseModel):
|
||||
class Type(StrEnum):
|
||||
Github = PluginInstallationSource.Github
|
||||
Marketplace = PluginInstallationSource.Marketplace
|
||||
Package = PluginInstallationSource.Package
|
||||
|
||||
class Github(BaseModel):
|
||||
repo: str
|
||||
version: str
|
||||
package: str
|
||||
github_plugin_unique_identifier: str
|
||||
|
||||
@property
|
||||
def plugin_unique_identifier(self) -> str:
|
||||
return self.github_plugin_unique_identifier
|
||||
|
||||
class Marketplace(BaseModel):
|
||||
marketplace_plugin_unique_identifier: str
|
||||
version: str | None = None
|
||||
|
||||
@property
|
||||
def plugin_unique_identifier(self) -> str:
|
||||
return self.marketplace_plugin_unique_identifier
|
||||
|
||||
class Package(BaseModel):
|
||||
plugin_unique_identifier: str
|
||||
version: str | None = None
|
||||
|
||||
type: Type
|
||||
value: Github | Marketplace | Package
|
||||
current_identifier: str | None = None
|
||||
|
||||
|
||||
class MissingPluginDependency(BaseModel):
|
||||
plugin_unique_identifier: str
|
||||
current_identifier: str | None = None
|
||||
259
dify/api/core/plugin/entities/plugin_daemon.py
Normal file
259
dify/api/core/plugin/entities/plugin_daemon.py
Normal file
@@ -0,0 +1,259 @@
|
||||
import enum
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.agent.plugin_entities import AgentProviderEntityWithPlugin
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.plugin.entities.base import BasePluginEntity
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
|
||||
from core.trigger.entities.entities import TriggerProviderEntity
|
||||
|
||||
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
|
||||
|
||||
|
||||
class PluginDaemonBasicResponse(BaseModel, Generic[T]):
|
||||
"""
|
||||
Basic response from plugin daemon.
|
||||
"""
|
||||
|
||||
code: int
|
||||
message: str
|
||||
data: T | None = None
|
||||
|
||||
|
||||
class InstallPluginMessage(BaseModel):
|
||||
"""
|
||||
Message for installing a plugin.
|
||||
"""
|
||||
|
||||
class Event(StrEnum):
|
||||
Info = "info"
|
||||
Done = "done"
|
||||
Error = "error"
|
||||
|
||||
event: Event
|
||||
data: str
|
||||
|
||||
|
||||
class PluginToolProviderEntity(BaseModel):
|
||||
provider: str
|
||||
plugin_unique_identifier: str
|
||||
plugin_id: str
|
||||
declaration: ToolProviderEntityWithPlugin
|
||||
|
||||
|
||||
class PluginDatasourceProviderEntity(BaseModel):
|
||||
provider: str
|
||||
plugin_unique_identifier: str
|
||||
plugin_id: str
|
||||
is_authorized: bool = False
|
||||
declaration: DatasourceProviderEntityWithPlugin
|
||||
|
||||
|
||||
class PluginAgentProviderEntity(BaseModel):
|
||||
provider: str
|
||||
plugin_unique_identifier: str
|
||||
plugin_id: str
|
||||
declaration: AgentProviderEntityWithPlugin
|
||||
meta: PluginDeclaration.Meta
|
||||
|
||||
|
||||
class PluginBasicBooleanResponse(BaseModel):
|
||||
"""
|
||||
Basic boolean response from plugin daemon.
|
||||
"""
|
||||
|
||||
result: bool
|
||||
credentials: dict | None = None
|
||||
|
||||
|
||||
class PluginModelSchemaEntity(BaseModel):
|
||||
model_schema: AIModelEntity = Field(description="The model schema.")
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class PluginModelProviderEntity(BaseModel):
|
||||
id: str = Field(description="ID")
|
||||
created_at: datetime = Field(description="The created at time of the model provider.")
|
||||
updated_at: datetime = Field(description="The updated at time of the model provider.")
|
||||
provider: str = Field(description="The provider of the model.")
|
||||
tenant_id: str = Field(description="The tenant ID.")
|
||||
plugin_unique_identifier: str = Field(description="The plugin unique identifier.")
|
||||
plugin_id: str = Field(description="The plugin ID.")
|
||||
declaration: ProviderEntity = Field(description="The declaration of the model provider.")
|
||||
|
||||
|
||||
class PluginTextEmbeddingNumTokensResponse(BaseModel):
|
||||
"""
|
||||
Response for number of tokens.
|
||||
"""
|
||||
|
||||
num_tokens: list[int] = Field(description="The number of tokens.")
|
||||
|
||||
|
||||
class PluginLLMNumTokensResponse(BaseModel):
|
||||
"""
|
||||
Response for number of tokens.
|
||||
"""
|
||||
|
||||
num_tokens: int = Field(description="The number of tokens.")
|
||||
|
||||
|
||||
class PluginStringResultResponse(BaseModel):
|
||||
result: str = Field(description="The result of the string.")
|
||||
|
||||
|
||||
class PluginVoiceEntity(BaseModel):
|
||||
name: str = Field(description="The name of the voice.")
|
||||
value: str = Field(description="The value of the voice.")
|
||||
|
||||
|
||||
class PluginVoicesResponse(BaseModel):
|
||||
voices: list[PluginVoiceEntity] = Field(description="The result of the voices.")
|
||||
|
||||
|
||||
class PluginDaemonError(BaseModel):
|
||||
"""
|
||||
Error from plugin daemon.
|
||||
"""
|
||||
|
||||
error_type: str
|
||||
message: str
|
||||
|
||||
|
||||
class PluginDaemonInnerError(Exception):
|
||||
code: int
|
||||
message: str
|
||||
|
||||
def __init__(self, code: int, message: str):
|
||||
self.code = code
|
||||
self.message = message
|
||||
|
||||
|
||||
class PluginInstallTaskStatus(StrEnum):
|
||||
Pending = "pending"
|
||||
Running = "running"
|
||||
Success = "success"
|
||||
Failed = "failed"
|
||||
|
||||
|
||||
class PluginInstallTaskPluginStatus(BaseModel):
|
||||
plugin_unique_identifier: str = Field(description="The plugin unique identifier of the install task.")
|
||||
plugin_id: str = Field(description="The plugin ID of the install task.")
|
||||
status: PluginInstallTaskStatus = Field(description="The status of the install task.")
|
||||
message: str = Field(description="The message of the install task.")
|
||||
icon: str = Field(description="The icon of the plugin.")
|
||||
labels: I18nObject = Field(description="The labels of the plugin.")
|
||||
|
||||
|
||||
class PluginInstallTask(BasePluginEntity):
|
||||
status: PluginInstallTaskStatus = Field(description="The status of the install task.")
|
||||
total_plugins: int = Field(description="The total number of plugins to be installed.")
|
||||
completed_plugins: int = Field(description="The number of plugins that have been installed.")
|
||||
plugins: list[PluginInstallTaskPluginStatus] = Field(description="The status of the plugins.")
|
||||
|
||||
|
||||
class PluginInstallTaskStartResponse(BaseModel):
|
||||
all_installed: bool = Field(description="Whether all plugins are installed.")
|
||||
task_id: str = Field(description="The ID of the install task.")
|
||||
|
||||
|
||||
class PluginVerification(BaseModel):
|
||||
"""
|
||||
Verification of the plugin.
|
||||
"""
|
||||
|
||||
class AuthorizedCategory(StrEnum):
|
||||
Langgenius = "langgenius"
|
||||
Partner = "partner"
|
||||
Community = "community"
|
||||
|
||||
authorized_category: AuthorizedCategory = Field(description="The authorized category of the plugin.")
|
||||
|
||||
|
||||
class PluginDecodeResponse(BaseModel):
|
||||
unique_identifier: str = Field(description="The unique identifier of the plugin.")
|
||||
manifest: PluginDeclaration
|
||||
verification: PluginVerification | None = Field(default=None, description="Basic verification information")
|
||||
|
||||
|
||||
class PluginOAuthAuthorizationUrlResponse(BaseModel):
|
||||
authorization_url: str = Field(description="The URL of the authorization.")
|
||||
|
||||
|
||||
class PluginOAuthCredentialsResponse(BaseModel):
|
||||
metadata: Mapping[str, Any] = Field(
|
||||
default_factory=dict, description="The metadata of the OAuth, like avatar url, name, etc."
|
||||
)
|
||||
expires_at: int = Field(default=-1, description="The expires at time of the credentials. UTC timestamp.")
|
||||
credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.")
|
||||
|
||||
|
||||
class PluginListResponse(BaseModel):
|
||||
list: list[PluginEntity]
|
||||
total: int
|
||||
|
||||
|
||||
class PluginDynamicSelectOptionsResponse(BaseModel):
|
||||
options: Sequence[PluginParameterOption] = Field(description="The options of the dynamic select.")
|
||||
|
||||
|
||||
class PluginTriggerProviderEntity(BaseModel):
|
||||
provider: str
|
||||
plugin_unique_identifier: str
|
||||
plugin_id: str
|
||||
declaration: TriggerProviderEntity
|
||||
|
||||
|
||||
class CredentialType(enum.StrEnum):
|
||||
API_KEY = "api-key"
|
||||
OAUTH2 = "oauth2"
|
||||
UNAUTHORIZED = "unauthorized"
|
||||
|
||||
def get_name(self):
|
||||
if self == CredentialType.API_KEY:
|
||||
return "API KEY"
|
||||
elif self == CredentialType.OAUTH2:
|
||||
return "AUTH"
|
||||
elif self == CredentialType.UNAUTHORIZED:
|
||||
return "UNAUTHORIZED"
|
||||
else:
|
||||
return self.value.replace("-", " ").upper()
|
||||
|
||||
def is_editable(self):
|
||||
return self == CredentialType.API_KEY
|
||||
|
||||
def is_validate_allowed(self):
|
||||
return self == CredentialType.API_KEY
|
||||
|
||||
@classmethod
|
||||
def values(cls):
|
||||
return [item.value for item in cls]
|
||||
|
||||
@classmethod
|
||||
def of(cls, credential_type: str) -> "CredentialType":
|
||||
type_name = credential_type.lower()
|
||||
if type_name in {"api-key", "api_key"}:
|
||||
return cls.API_KEY
|
||||
elif type_name in {"oauth2", "oauth"}:
|
||||
return cls.OAUTH2
|
||||
elif type_name == "unauthorized":
|
||||
return cls.UNAUTHORIZED
|
||||
else:
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
|
||||
class PluginReadmeResponse(BaseModel):
|
||||
content: str = Field(description="The readme of the plugin.")
|
||||
language: str = Field(description="The language of the readme.")
|
||||
284
dify/api/core/plugin/entities/request.py
Normal file
284
dify/api/core/plugin/entities/request.py
Normal file
@@ -0,0 +1,284 @@
|
||||
import binascii
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import Response
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageRole,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.utils.http_parser import deserialize_response
|
||||
from core.workflow.nodes.parameter_extractor.entities import (
|
||||
ModelConfig as ParameterExtractorModelConfig,
|
||||
)
|
||||
from core.workflow.nodes.parameter_extractor.entities import (
|
||||
ParameterConfig,
|
||||
)
|
||||
from core.workflow.nodes.question_classifier.entities import (
|
||||
ClassConfig,
|
||||
)
|
||||
from core.workflow.nodes.question_classifier.entities import (
|
||||
ModelConfig as QuestionClassifierModelConfig,
|
||||
)
|
||||
|
||||
|
||||
class InvokeCredentials(BaseModel):
|
||||
tool_credentials: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Map of tool provider to credential id, used to store the credential id for the tool provider.",
|
||||
)
|
||||
|
||||
|
||||
class PluginInvokeContext(BaseModel):
|
||||
credentials: InvokeCredentials | None = Field(
|
||||
default_factory=InvokeCredentials,
|
||||
description="Credentials context for the plugin invocation or backward invocation.",
|
||||
)
|
||||
|
||||
|
||||
class RequestInvokeTool(BaseModel):
|
||||
"""
|
||||
Request to invoke a tool
|
||||
"""
|
||||
|
||||
tool_type: Literal["builtin", "workflow", "api", "mcp"]
|
||||
provider: str
|
||||
tool: str
|
||||
tool_parameters: dict
|
||||
credential_id: str | None = None
|
||||
|
||||
|
||||
class BaseRequestInvokeModel(BaseModel):
|
||||
provider: str
|
||||
model: str
|
||||
model_type: ModelType
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class RequestInvokeLLM(BaseRequestInvokeModel):
|
||||
"""
|
||||
Request to invoke LLM
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.LLM
|
||||
mode: str
|
||||
completion_params: dict[str, Any] = Field(default_factory=dict)
|
||||
prompt_messages: list[PromptMessage] = Field(default_factory=list)
|
||||
tools: list[PromptMessageTool] | None = Field(default_factory=list[PromptMessageTool])
|
||||
stop: list[str] | None = Field(default_factory=list[str])
|
||||
stream: bool | None = False
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@field_validator("prompt_messages", mode="before")
|
||||
@classmethod
|
||||
def convert_prompt_messages(cls, v):
|
||||
if not isinstance(v, list):
|
||||
raise ValueError("prompt_messages must be a list")
|
||||
|
||||
for i in range(len(v)):
|
||||
if v[i]["role"] == PromptMessageRole.USER:
|
||||
v[i] = UserPromptMessage.model_validate(v[i])
|
||||
elif v[i]["role"] == PromptMessageRole.ASSISTANT:
|
||||
v[i] = AssistantPromptMessage.model_validate(v[i])
|
||||
elif v[i]["role"] == PromptMessageRole.SYSTEM:
|
||||
v[i] = SystemPromptMessage.model_validate(v[i])
|
||||
elif v[i]["role"] == PromptMessageRole.TOOL:
|
||||
v[i] = ToolPromptMessage.model_validate(v[i])
|
||||
else:
|
||||
v[i] = PromptMessage.model_validate(v[i])
|
||||
|
||||
return v
|
||||
|
||||
|
||||
class RequestInvokeLLMWithStructuredOutput(RequestInvokeLLM):
|
||||
"""
|
||||
Request to invoke LLM with structured output
|
||||
"""
|
||||
|
||||
structured_output_schema: dict[str, Any] = Field(
|
||||
default_factory=dict, description="The schema of the structured output in JSON schema format"
|
||||
)
|
||||
|
||||
|
||||
class RequestInvokeTextEmbedding(BaseRequestInvokeModel):
|
||||
"""
|
||||
Request to invoke text embedding
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.TEXT_EMBEDDING
|
||||
texts: list[str]
|
||||
|
||||
|
||||
class RequestInvokeRerank(BaseRequestInvokeModel):
|
||||
"""
|
||||
Request to invoke rerank
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.RERANK
|
||||
query: str
|
||||
docs: list[str]
|
||||
score_threshold: float
|
||||
top_n: int
|
||||
|
||||
|
||||
class RequestInvokeTTS(BaseRequestInvokeModel):
|
||||
"""
|
||||
Request to invoke TTS
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.TTS
|
||||
content_text: str
|
||||
voice: str
|
||||
|
||||
|
||||
class RequestInvokeSpeech2Text(BaseRequestInvokeModel):
|
||||
"""
|
||||
Request to invoke speech2text
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.SPEECH2TEXT
|
||||
file: bytes
|
||||
|
||||
@field_validator("file", mode="before")
|
||||
@classmethod
|
||||
def convert_file(cls, v):
|
||||
# hex string to bytes
|
||||
if isinstance(v, str):
|
||||
return bytes.fromhex(v)
|
||||
else:
|
||||
raise ValueError("file must be a hex string")
|
||||
|
||||
|
||||
class RequestInvokeModeration(BaseRequestInvokeModel):
|
||||
"""
|
||||
Request to invoke moderation
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.MODERATION
|
||||
text: str
|
||||
|
||||
|
||||
class RequestInvokeParameterExtractorNode(BaseModel):
|
||||
"""
|
||||
Request to invoke parameter extractor node
|
||||
"""
|
||||
|
||||
parameters: list[ParameterConfig]
|
||||
model: ParameterExtractorModelConfig
|
||||
instruction: str
|
||||
query: str
|
||||
|
||||
|
||||
class RequestInvokeQuestionClassifierNode(BaseModel):
|
||||
"""
|
||||
Request to invoke question classifier node
|
||||
"""
|
||||
|
||||
query: str
|
||||
model: QuestionClassifierModelConfig
|
||||
classes: list[ClassConfig]
|
||||
instruction: str
|
||||
|
||||
|
||||
class RequestInvokeApp(BaseModel):
|
||||
"""
|
||||
Request to invoke app
|
||||
"""
|
||||
|
||||
app_id: str
|
||||
inputs: dict[str, Any]
|
||||
query: str | None = None
|
||||
response_mode: Literal["blocking", "streaming"]
|
||||
conversation_id: str | None = None
|
||||
user: str | None = None
|
||||
files: list[dict] = Field(default_factory=list)
|
||||
|
||||
|
||||
class RequestInvokeEncrypt(BaseModel):
|
||||
"""
|
||||
Request to encryption
|
||||
"""
|
||||
|
||||
opt: Literal["encrypt", "decrypt", "clear"]
|
||||
namespace: Literal["endpoint"]
|
||||
identity: str
|
||||
data: dict = Field(default_factory=dict)
|
||||
config: list[BasicProviderConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
class RequestInvokeSummary(BaseModel):
|
||||
"""
|
||||
Request to summary
|
||||
"""
|
||||
|
||||
text: str
|
||||
instruction: str
|
||||
|
||||
|
||||
class RequestRequestUploadFile(BaseModel):
|
||||
"""
|
||||
Request to upload file
|
||||
"""
|
||||
|
||||
filename: str
|
||||
mimetype: str
|
||||
|
||||
|
||||
class RequestFetchAppInfo(BaseModel):
|
||||
"""
|
||||
Request to fetch app info
|
||||
"""
|
||||
|
||||
app_id: str
|
||||
|
||||
|
||||
class TriggerInvokeEventResponse(BaseModel):
|
||||
variables: Mapping[str, Any] = Field(default_factory=dict)
|
||||
cancelled: bool = Field(default=False)
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True)
|
||||
|
||||
@field_validator("variables", mode="before")
|
||||
@classmethod
|
||||
def convert_variables(cls, v):
|
||||
if isinstance(v, str):
|
||||
return json.loads(v)
|
||||
else:
|
||||
return v
|
||||
|
||||
|
||||
class TriggerSubscriptionResponse(BaseModel):
|
||||
subscription: dict[str, Any]
|
||||
|
||||
|
||||
class TriggerValidateProviderCredentialsResponse(BaseModel):
|
||||
result: bool
|
||||
|
||||
|
||||
class TriggerDispatchResponse(BaseModel):
|
||||
user_id: str
|
||||
events: list[str]
|
||||
response: Response
|
||||
payload: Mapping[str, Any] = Field(default_factory=dict)
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True)
|
||||
|
||||
@field_validator("response", mode="before")
|
||||
@classmethod
|
||||
def convert_response(cls, v: str):
|
||||
try:
|
||||
return deserialize_response(binascii.unhexlify(v.encode()))
|
||||
except Exception as e:
|
||||
raise ValueError("Failed to deserialize response from hex string") from e
|
||||
Reference in New Issue
Block a user