dify
This commit is contained in:
222
dify/api/core/plugin/backwards_invocation/app.py
Normal file
222
dify/api/core/plugin/backwards_invocation/app.py
Normal file
@@ -0,0 +1,222 @@
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
|
||||
from core.app.apps.chat.app_generator import ChatAppGenerator
|
||||
from core.app.apps.completion.app_generator import CompletionAppGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
from extensions.ext_database import db
|
||||
from models import Account
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.end_user_service import EndUserService
|
||||
|
||||
|
||||
class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
@classmethod
|
||||
def fetch_app_info(cls, app_id: str, tenant_id: str) -> Mapping:
|
||||
"""
|
||||
Fetch app info
|
||||
"""
|
||||
app = cls._get_app(app_id, tenant_id)
|
||||
|
||||
"""Retrieve app parameters."""
|
||||
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app.workflow
|
||||
if workflow is None:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app.app_model_config
|
||||
if app_model_config is None:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
return {
|
||||
"data": get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def invoke_app(
|
||||
cls,
|
||||
app_id: str,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: str | None,
|
||||
query: str | None,
|
||||
stream: bool,
|
||||
inputs: Mapping,
|
||||
files: list[dict],
|
||||
) -> Generator[Mapping | str, None, None] | Mapping:
|
||||
"""
|
||||
invoke app
|
||||
"""
|
||||
app = cls._get_app(app_id, tenant_id)
|
||||
if not user_id:
|
||||
user = EndUserService.get_or_create_end_user(app)
|
||||
else:
|
||||
user = cls._get_user(user_id)
|
||||
|
||||
conversation_id = conversation_id or ""
|
||||
|
||||
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.AGENT_CHAT, AppMode.CHAT}:
|
||||
if not query:
|
||||
raise ValueError("missing query")
|
||||
|
||||
return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files)
|
||||
elif app.mode == AppMode.WORKFLOW:
|
||||
return cls.invoke_workflow_app(app, user, stream, inputs, files)
|
||||
elif app.mode == AppMode.COMPLETION:
|
||||
return cls.invoke_completion_app(app, user, stream, inputs, files)
|
||||
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
@classmethod
|
||||
def invoke_chat_app(
|
||||
cls,
|
||||
app: App,
|
||||
user: Account | EndUser,
|
||||
conversation_id: str,
|
||||
query: str,
|
||||
stream: bool,
|
||||
inputs: Mapping,
|
||||
files: list[dict],
|
||||
) -> Generator[Mapping | str, None, None] | Mapping:
|
||||
"""
|
||||
invoke chat app
|
||||
"""
|
||||
if app.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow = app.workflow
|
||||
if not workflow:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
return AdvancedChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args={
|
||||
"inputs": inputs,
|
||||
"query": query,
|
||||
"files": files,
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
)
|
||||
elif app.mode == AppMode.AGENT_CHAT:
|
||||
return AgentChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
user=user,
|
||||
args={
|
||||
"inputs": inputs,
|
||||
"query": query,
|
||||
"files": files,
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
)
|
||||
elif app.mode == AppMode.CHAT:
|
||||
return ChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
user=user,
|
||||
args={
|
||||
"inputs": inputs,
|
||||
"query": query,
|
||||
"files": files,
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
)
|
||||
else:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
@classmethod
|
||||
def invoke_workflow_app(
|
||||
cls,
|
||||
app: App,
|
||||
user: EndUser | Account,
|
||||
stream: bool,
|
||||
inputs: Mapping,
|
||||
files: list[dict],
|
||||
) -> Generator[Mapping | str, None, None] | Mapping:
|
||||
"""
|
||||
invoke workflow app
|
||||
"""
|
||||
workflow = app.workflow
|
||||
if not workflow:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
return WorkflowAppGenerator().generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args={"inputs": inputs, "files": files},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
call_depth=1,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def invoke_completion_app(
|
||||
cls,
|
||||
app: App,
|
||||
user: EndUser | Account,
|
||||
stream: bool,
|
||||
inputs: Mapping,
|
||||
files: list[dict],
|
||||
) -> Generator[Mapping | str, None, None] | Mapping:
|
||||
"""
|
||||
invoke completion app
|
||||
"""
|
||||
return CompletionAppGenerator().generate(
|
||||
app_model=app,
|
||||
user=user,
|
||||
args={"inputs": inputs, "files": files},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_user(cls, user_id: str) -> Union[EndUser, Account]:
|
||||
"""
|
||||
get the user by user id
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(EndUser).where(EndUser.id == user_id)
|
||||
user = session.scalar(stmt)
|
||||
if not user:
|
||||
stmt = select(Account).where(Account.id == user_id)
|
||||
user = session.scalar(stmt)
|
||||
|
||||
if not user:
|
||||
raise ValueError("user not found")
|
||||
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
def _get_app(cls, app_id: str, tenant_id: str) -> App:
|
||||
"""
|
||||
get app
|
||||
"""
|
||||
try:
|
||||
app = db.session.query(App).where(App.id == app_id).where(App.tenant_id == tenant_id).first()
|
||||
except Exception:
|
||||
raise ValueError("app not found")
|
||||
|
||||
if not app:
|
||||
raise ValueError("app not found")
|
||||
|
||||
return app
|
||||
27
dify/api/core/plugin/backwards_invocation/base.py
Normal file
27
dify/api/core/plugin/backwards_invocation/base.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BaseBackwardsInvocation:
|
||||
@classmethod
|
||||
def convert_to_event_stream(cls, response: Generator[BaseModel | Mapping | str, None, None] | BaseModel | Mapping):
|
||||
if isinstance(response, Generator):
|
||||
try:
|
||||
for chunk in response:
|
||||
if isinstance(chunk, BaseModel | dict):
|
||||
yield BaseBackwardsInvocationResponse(data=chunk).model_dump_json().encode()
|
||||
except Exception as e:
|
||||
error_message = BaseBackwardsInvocationResponse(error=str(e)).model_dump_json()
|
||||
yield error_message.encode()
|
||||
else:
|
||||
yield BaseBackwardsInvocationResponse(data=response).model_dump_json().encode()
|
||||
|
||||
|
||||
T = TypeVar("T", bound=dict | Mapping | str | bool | int | BaseModel)
|
||||
|
||||
|
||||
class BaseBackwardsInvocationResponse(BaseModel, Generic[T]):
|
||||
data: T | None = None
|
||||
error: str = ""
|
||||
34
dify/api/core/plugin/backwards_invocation/encrypt.py
Normal file
34
dify/api/core/plugin/backwards_invocation/encrypt.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from core.helper.provider_cache import SingletonProviderCredentialsCache
|
||||
from core.plugin.entities.request import RequestInvokeEncrypt
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
from models.account import Tenant
|
||||
|
||||
|
||||
class PluginEncrypter:
|
||||
@classmethod
|
||||
def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt):
|
||||
encrypter, cache = create_provider_encrypter(
|
||||
tenant_id=tenant.id,
|
||||
config=payload.config,
|
||||
cache=SingletonProviderCredentialsCache(
|
||||
tenant_id=tenant.id,
|
||||
provider_type=payload.namespace,
|
||||
provider_identity=payload.identity,
|
||||
),
|
||||
)
|
||||
|
||||
if payload.opt == "encrypt":
|
||||
return {
|
||||
"data": encrypter.encrypt(payload.data),
|
||||
}
|
||||
elif payload.opt == "decrypt":
|
||||
return {
|
||||
"data": encrypter.decrypt(payload.data),
|
||||
}
|
||||
elif payload.opt == "clear":
|
||||
cache.delete()
|
||||
return {
|
||||
"data": {},
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Invalid opt: {payload.opt}")
|
||||
410
dify/api/core/plugin/backwards_invocation/model.py
Normal file
410
dify/api/core/plugin/backwards_invocation/model.py
Normal file
@@ -0,0 +1,410 @@
|
||||
import tempfile
|
||||
from binascii import hexlify, unhexlify
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
LLMResultWithStructuredOutput,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
from core.plugin.entities.request import (
|
||||
RequestInvokeLLM,
|
||||
RequestInvokeLLMWithStructuredOutput,
|
||||
RequestInvokeModeration,
|
||||
RequestInvokeRerank,
|
||||
RequestInvokeSpeech2Text,
|
||||
RequestInvokeSummary,
|
||||
RequestInvokeTextEmbedding,
|
||||
RequestInvokeTTS,
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
|
||||
from core.workflow.nodes.llm import llm_utils
|
||||
from models.account import Tenant
|
||||
|
||||
|
||||
class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
@classmethod
|
||||
def invoke_llm(
|
||||
cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
|
||||
) -> Generator[LLMResultChunk, None, None] | LLMResult:
|
||||
"""
|
||||
invoke llm
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
)
|
||||
|
||||
# invoke model
|
||||
response = model_instance.invoke_llm(
|
||||
prompt_messages=payload.prompt_messages,
|
||||
model_parameters=payload.completion_params,
|
||||
tools=payload.tools,
|
||||
stop=payload.stop,
|
||||
stream=True if payload.stream is None else payload.stream,
|
||||
user=user_id,
|
||||
)
|
||||
|
||||
if isinstance(response, Generator):
|
||||
|
||||
def handle() -> Generator[LLMResultChunk, None, None]:
|
||||
for chunk in response:
|
||||
if chunk.delta.usage:
|
||||
llm_utils.deduct_llm_quota(
|
||||
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
|
||||
)
|
||||
chunk.prompt_messages = []
|
||||
yield chunk
|
||||
|
||||
return handle()
|
||||
else:
|
||||
if response.usage:
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
|
||||
def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
|
||||
yield LLMResultChunk(
|
||||
model=response.model,
|
||||
prompt_messages=[],
|
||||
system_fingerprint=response.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=response.message,
|
||||
usage=response.usage,
|
||||
finish_reason="",
|
||||
),
|
||||
)
|
||||
|
||||
return handle_non_streaming(response)
|
||||
|
||||
@classmethod
|
||||
def invoke_llm_with_structured_output(
|
||||
cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLMWithStructuredOutput
|
||||
):
|
||||
"""
|
||||
invoke llm with structured output
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
)
|
||||
|
||||
model_schema = model_instance.model_type_instance.get_model_schema(payload.model, model_instance.credentials)
|
||||
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model schema not found for {payload.model}")
|
||||
|
||||
response = invoke_llm_with_structured_output(
|
||||
provider=payload.provider,
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=payload.prompt_messages,
|
||||
json_schema=payload.structured_output_schema,
|
||||
tools=payload.tools,
|
||||
stop=payload.stop,
|
||||
stream=True if payload.stream is None else payload.stream,
|
||||
user=user_id,
|
||||
model_parameters=payload.completion_params,
|
||||
)
|
||||
|
||||
if isinstance(response, Generator):
|
||||
|
||||
def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
for chunk in response:
|
||||
if chunk.delta.usage:
|
||||
llm_utils.deduct_llm_quota(
|
||||
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
|
||||
)
|
||||
chunk.prompt_messages = []
|
||||
yield chunk
|
||||
|
||||
return handle()
|
||||
else:
|
||||
if response.usage:
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
|
||||
def handle_non_streaming(
|
||||
response: LLMResultWithStructuredOutput,
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
model=response.model,
|
||||
prompt_messages=[],
|
||||
system_fingerprint=response.system_fingerprint,
|
||||
structured_output=response.structured_output,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=response.message,
|
||||
usage=response.usage,
|
||||
finish_reason="",
|
||||
),
|
||||
)
|
||||
|
||||
return handle_non_streaming(response)
|
||||
|
||||
@classmethod
|
||||
def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
|
||||
"""
|
||||
invoke text embedding
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
)
|
||||
|
||||
# invoke model
|
||||
response = model_instance.invoke_text_embedding(
|
||||
texts=payload.texts,
|
||||
user=user_id,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def invoke_rerank(cls, user_id: str, tenant: Tenant, payload: RequestInvokeRerank):
|
||||
"""
|
||||
invoke rerank
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
)
|
||||
|
||||
# invoke model
|
||||
response = model_instance.invoke_rerank(
|
||||
query=payload.query,
|
||||
docs=payload.docs,
|
||||
score_threshold=payload.score_threshold,
|
||||
top_n=payload.top_n,
|
||||
user=user_id,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def invoke_tts(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTTS):
|
||||
"""
|
||||
invoke tts
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
)
|
||||
|
||||
# invoke model
|
||||
response = model_instance.invoke_tts(
|
||||
content_text=payload.content_text,
|
||||
tenant_id=tenant.id,
|
||||
voice=payload.voice,
|
||||
user=user_id,
|
||||
)
|
||||
|
||||
def handle() -> Generator[dict, None, None]:
|
||||
for chunk in response:
|
||||
yield {"result": hexlify(chunk).decode("utf-8")}
|
||||
|
||||
return handle()
|
||||
|
||||
@classmethod
|
||||
def invoke_speech2text(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSpeech2Text):
|
||||
"""
|
||||
invoke speech2text
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
)
|
||||
|
||||
# invoke model
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp3", mode="wb", delete=True) as temp:
|
||||
temp.write(unhexlify(payload.file))
|
||||
temp.flush()
|
||||
temp.seek(0)
|
||||
|
||||
response = model_instance.invoke_speech2text(
|
||||
file=temp,
|
||||
user=user_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"result": response,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def invoke_moderation(cls, user_id: str, tenant: Tenant, payload: RequestInvokeModeration):
|
||||
"""
|
||||
invoke moderation
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
)
|
||||
|
||||
# invoke model
|
||||
response = model_instance.invoke_moderation(
|
||||
text=payload.text,
|
||||
user=user_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"result": response,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_system_model_max_tokens(cls, tenant_id: str) -> int:
|
||||
"""
|
||||
get system model max tokens
|
||||
"""
|
||||
return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id)
|
||||
|
||||
@classmethod
|
||||
def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
|
||||
"""
|
||||
get prompt tokens
|
||||
"""
|
||||
return ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages)
|
||||
|
||||
@classmethod
|
||||
def invoke_system_model(
|
||||
cls,
|
||||
user_id: str,
|
||||
tenant: Tenant,
|
||||
prompt_messages: list[PromptMessage],
|
||||
) -> LLMResult:
|
||||
"""
|
||||
invoke system model
|
||||
"""
|
||||
return ModelInvocationUtils.invoke(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant.id,
|
||||
tool_type=ToolProviderType.PLUGIN,
|
||||
tool_name="plugin",
|
||||
prompt_messages=prompt_messages,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def invoke_summary(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSummary):
|
||||
"""
|
||||
invoke summary
|
||||
"""
|
||||
max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id)
|
||||
content = payload.text
|
||||
|
||||
SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
|
||||
and you can quickly aimed at the main point of an webpage and reproduce it in your own words but
|
||||
retain the original meaning and keep the key points.
|
||||
however, the text you got is too long, what you got is possible a part of the text.
|
||||
Please summarize the text you got.
|
||||
|
||||
Here is the extra instruction you need to follow:
|
||||
<extra_instruction>
|
||||
{payload.instruction}
|
||||
</extra_instruction>
|
||||
"""
|
||||
|
||||
if (
|
||||
cls.get_prompt_tokens(
|
||||
tenant_id=tenant.id,
|
||||
prompt_messages=[UserPromptMessage(content=content)],
|
||||
)
|
||||
< max_tokens * 0.6
|
||||
):
|
||||
return content
|
||||
|
||||
def get_prompt_tokens(content: str) -> int:
|
||||
return cls.get_prompt_tokens(
|
||||
tenant_id=tenant.id,
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
|
||||
UserPromptMessage(content=content),
|
||||
],
|
||||
)
|
||||
|
||||
def summarize(content: str) -> str:
|
||||
summary = cls.invoke_system_model(
|
||||
user_id=user_id,
|
||||
tenant=tenant,
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
|
||||
UserPromptMessage(content=content),
|
||||
],
|
||||
)
|
||||
|
||||
assert isinstance(summary.message.content, str)
|
||||
return summary.message.content
|
||||
|
||||
lines = content.split("\n")
|
||||
new_lines: list[str] = []
|
||||
# split long line into multiple lines
|
||||
for i in range(len(lines)):
|
||||
line = lines[i]
|
||||
if not line.strip():
|
||||
continue
|
||||
if len(line) < max_tokens * 0.5:
|
||||
new_lines.append(line)
|
||||
elif get_prompt_tokens(line) > max_tokens * 0.7:
|
||||
while get_prompt_tokens(line) > max_tokens * 0.7:
|
||||
new_lines.append(line[: int(max_tokens * 0.5)])
|
||||
line = line[int(max_tokens * 0.5) :]
|
||||
new_lines.append(line)
|
||||
else:
|
||||
new_lines.append(line)
|
||||
|
||||
# merge lines into messages with max tokens
|
||||
messages: list[str] = []
|
||||
for line in new_lines:
|
||||
if len(messages) == 0:
|
||||
messages.append(line)
|
||||
else:
|
||||
if len(messages[-1]) + len(line) < max_tokens * 0.5:
|
||||
messages[-1] += line
|
||||
if get_prompt_tokens(messages[-1] + line) > max_tokens * 0.7:
|
||||
messages.append(line)
|
||||
else:
|
||||
messages[-1] += line
|
||||
|
||||
summaries = []
|
||||
for i in range(len(messages)):
|
||||
message = messages[i]
|
||||
summary = summarize(message)
|
||||
summaries.append(summary)
|
||||
|
||||
result = "\n".join(summaries)
|
||||
|
||||
if (
|
||||
cls.get_prompt_tokens(
|
||||
tenant_id=tenant.id,
|
||||
prompt_messages=[UserPromptMessage(content=result)],
|
||||
)
|
||||
> max_tokens * 0.7
|
||||
):
|
||||
return cls.invoke_summary(
|
||||
user_id=user_id,
|
||||
tenant=tenant,
|
||||
payload=RequestInvokeSummary(text=result, instruction=payload.instruction),
|
||||
)
|
||||
|
||||
return result
|
||||
119
dify/api/core/plugin/backwards_invocation/node.py
Normal file
119
dify/api/core/plugin/backwards_invocation/node.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.parameter_extractor.entities import (
|
||||
ModelConfig as ParameterExtractorModelConfig,
|
||||
)
|
||||
from core.workflow.nodes.parameter_extractor.entities import (
|
||||
ParameterConfig,
|
||||
ParameterExtractorNodeData,
|
||||
)
|
||||
from core.workflow.nodes.question_classifier.entities import (
|
||||
ClassConfig,
|
||||
QuestionClassifierNodeData,
|
||||
)
|
||||
from core.workflow.nodes.question_classifier.entities import (
|
||||
ModelConfig as QuestionClassifierModelConfig,
|
||||
)
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
|
||||
@classmethod
|
||||
def invoke_parameter_extractor(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
parameters: list[ParameterConfig],
|
||||
model_config: ParameterExtractorModelConfig,
|
||||
instruction: str,
|
||||
query: str,
|
||||
):
|
||||
"""
|
||||
Invoke parameter extractor node.
|
||||
|
||||
:param tenant_id: str
|
||||
:param user_id: str
|
||||
:param parameters: list[ParameterConfig]
|
||||
:param model_config: ModelConfig
|
||||
:param instruction: str
|
||||
:param query: str
|
||||
:return: dict
|
||||
"""
|
||||
# FIXME(-LAN-): Avoid import service into core
|
||||
workflow_service = WorkflowService()
|
||||
node_id = "1919810"
|
||||
node_data = ParameterExtractorNodeData(
|
||||
title="parameter_extractor",
|
||||
desc="parameter_extractor",
|
||||
parameters=parameters,
|
||||
reasoning_mode="function_call",
|
||||
query=[node_id, "query"],
|
||||
model=model_config,
|
||||
instruction=instruction, # instruct with variables are not supported
|
||||
)
|
||||
node_data_dict = node_data.model_dump()
|
||||
node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR
|
||||
execution = workflow_service.run_free_workflow_node(
|
||||
node_data_dict,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
node_id=node_id,
|
||||
user_inputs={
|
||||
f"{node_id}.query": query,
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
"inputs": execution.inputs,
|
||||
"outputs": execution.outputs,
|
||||
"process_data": execution.process_data,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def invoke_question_classifier(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
model_config: QuestionClassifierModelConfig,
|
||||
classes: list[ClassConfig],
|
||||
instruction: str,
|
||||
query: str,
|
||||
):
|
||||
"""
|
||||
Invoke question classifier node.
|
||||
|
||||
:param tenant_id: str
|
||||
:param user_id: str
|
||||
:param model_config: ModelConfig
|
||||
:param classes: list[ClassConfig]
|
||||
:param instruction: str
|
||||
:param query: str
|
||||
:return: dict
|
||||
"""
|
||||
# FIXME(-LAN-): Avoid import service into core
|
||||
workflow_service = WorkflowService()
|
||||
node_id = "1919810"
|
||||
node_data = QuestionClassifierNodeData(
|
||||
title="question_classifier",
|
||||
desc="question_classifier",
|
||||
query_variable_selector=[node_id, "query"],
|
||||
model=model_config,
|
||||
classes=classes,
|
||||
instruction=instruction, # instruct with variables are not supported
|
||||
)
|
||||
node_data_dict = node_data.model_dump()
|
||||
execution = workflow_service.run_free_workflow_node(
|
||||
node_data_dict,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
node_id=node_id,
|
||||
user_inputs={
|
||||
f"{node_id}.query": query,
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
"inputs": execution.inputs,
|
||||
"outputs": execution.outputs,
|
||||
"process_data": execution.process_data,
|
||||
}
|
||||
46
dify/api/core/plugin/backwards_invocation/tool.py
Normal file
46
dify/api/core/plugin/backwards_invocation/tool.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
|
||||
|
||||
class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
Backwards invocation for plugin tools.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def invoke_tool(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
tool_type: ToolProviderType,
|
||||
provider: str,
|
||||
tool_name: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
credential_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
invoke tool
|
||||
"""
|
||||
# get tool runtime
|
||||
try:
|
||||
tool_runtime = ToolManager.get_tool_runtime_from_plugin(
|
||||
tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id
|
||||
)
|
||||
response = ToolEngine.generic_invoke(
|
||||
tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1
|
||||
)
|
||||
|
||||
response = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
response, user_id=user_id, tenant_id=tenant_id
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
6
dify/api/core/plugin/endpoint/exc.py
Normal file
6
dify/api/core/plugin/endpoint/exc.py
Normal file
@@ -0,0 +1,6 @@
|
||||
class EndpointSetupFailedError(ValueError):
|
||||
"""
|
||||
Endpoint setup failed error
|
||||
"""
|
||||
|
||||
pass
|
||||
9
dify/api/core/plugin/entities/base.py
Normal file
9
dify/api/core/plugin/entities/base.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BasePluginEntity(BaseModel):
|
||||
id: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
30
dify/api/core/plugin/entities/bundle.py
Normal file
30
dify/api/core/plugin/entities/bundle.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.plugin.entities.plugin import PluginDeclaration, PluginInstallationSource
|
||||
|
||||
|
||||
class PluginBundleDependency(BaseModel):
|
||||
class Type(StrEnum):
|
||||
Github = PluginInstallationSource.Github.value
|
||||
Marketplace = PluginInstallationSource.Marketplace.value
|
||||
Package = PluginInstallationSource.Package.value
|
||||
|
||||
class Github(BaseModel):
|
||||
repo_address: str
|
||||
repo: str
|
||||
release: str
|
||||
packages: str
|
||||
|
||||
class Marketplace(BaseModel):
|
||||
organization: str
|
||||
plugin: str
|
||||
version: str
|
||||
|
||||
class Package(BaseModel):
|
||||
unique_identifier: str
|
||||
manifest: PluginDeclaration
|
||||
|
||||
type: Type
|
||||
value: Github | Marketplace | Package
|
||||
53
dify/api/core/plugin/entities/endpoint.py
Normal file
53
dify/api/core/plugin/entities/endpoint.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.base import BasePluginEntity
|
||||
|
||||
|
||||
class EndpointDeclaration(BaseModel):
|
||||
"""
|
||||
declaration of an endpoint
|
||||
"""
|
||||
|
||||
path: str
|
||||
method: str
|
||||
hidden: bool = Field(default=False)
|
||||
|
||||
|
||||
class EndpointProviderDeclaration(BaseModel):
|
||||
"""
|
||||
declaration of an endpoint group
|
||||
"""
|
||||
|
||||
settings: list[ProviderConfig] = Field(default_factory=list)
|
||||
endpoints: list[EndpointDeclaration] | None = Field(default_factory=list[EndpointDeclaration])
|
||||
|
||||
|
||||
class EndpointEntity(BasePluginEntity):
|
||||
"""
|
||||
entity of an endpoint
|
||||
"""
|
||||
|
||||
settings: dict
|
||||
tenant_id: str
|
||||
plugin_id: str
|
||||
expired_at: datetime
|
||||
declaration: EndpointProviderDeclaration = Field(default_factory=EndpointProviderDeclaration)
|
||||
|
||||
|
||||
class EndpointEntityWithInstance(EndpointEntity):
|
||||
name: str
|
||||
enabled: bool
|
||||
url: str
|
||||
hook_id: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def render_url_template(cls, values):
|
||||
if "url" not in values:
|
||||
url_template = dify_config.ENDPOINT_URL_TEMPLATE
|
||||
values["url"] = url_template.replace("{hook_id}", values["hook_id"])
|
||||
return values
|
||||
50
dify/api/core/plugin/entities/marketplace.py
Normal file
50
dify/api/core/plugin/entities/marketplace.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.plugin.entities.endpoint import EndpointProviderDeclaration
|
||||
from core.plugin.entities.plugin import PluginResourceRequirements
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderEntity
|
||||
|
||||
|
||||
class MarketplacePluginDeclaration(BaseModel):
|
||||
name: str = Field(..., description="Unique identifier for the plugin within the marketplace")
|
||||
org: str = Field(..., description="Organization or developer responsible for creating and maintaining the plugin")
|
||||
plugin_id: str = Field(..., description="Globally unique identifier for the plugin across all marketplaces")
|
||||
icon: str = Field(..., description="URL or path to the plugin's visual representation")
|
||||
label: I18nObject = Field(..., description="Localized display name for the plugin in different languages")
|
||||
brief: I18nObject = Field(..., description="Short, localized description of the plugin's functionality")
|
||||
resource: PluginResourceRequirements = Field(
|
||||
..., description="Specification of computational resources needed to run the plugin"
|
||||
)
|
||||
endpoint: EndpointProviderDeclaration | None = Field(
|
||||
None, description="Configuration for the plugin's API endpoint, if applicable"
|
||||
)
|
||||
model: ProviderEntity | None = Field(None, description="Details of the AI model used by the plugin, if any")
|
||||
tool: ToolProviderEntity | None = Field(
|
||||
None, description="Information about the tool functionality provided by the plugin, if any"
|
||||
)
|
||||
latest_version: str = Field(
|
||||
..., description="Most recent version number of the plugin available in the marketplace"
|
||||
)
|
||||
latest_package_identifier: str = Field(
|
||||
..., description="Unique identifier for the latest package release of the plugin"
|
||||
)
|
||||
status: str = Field(..., description="Indicate the status of marketplace plugin, enum from `active` `deleted`")
|
||||
deprecated_reason: str = Field(
|
||||
..., description="Not empty when status='deleted', indicates the reason why this plugin is deleted(deprecated)"
|
||||
)
|
||||
alternative_plugin_id: str = Field(
|
||||
..., description="Optional, indicates the alternative plugin for user to switch to"
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def transform_declaration(cls, data: dict):
|
||||
if "endpoint" in data and not data["endpoint"]:
|
||||
del data["endpoint"]
|
||||
if "model" in data and not data["model"]:
|
||||
del data["model"]
|
||||
if "tool" in data and not data["tool"]:
|
||||
del data["tool"]
|
||||
return data
|
||||
21
dify/api/core/plugin/entities/oauth.py
Normal file
21
dify/api/core/plugin/entities/oauth.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
|
||||
|
||||
class OAuthSchema(BaseModel):
|
||||
"""
|
||||
OAuth schema
|
||||
"""
|
||||
|
||||
client_schema: Sequence[ProviderConfig] = Field(
|
||||
default_factory=list,
|
||||
description="client schema like client_id, client_secret, etc.",
|
||||
)
|
||||
|
||||
credentials_schema: Sequence[ProviderConfig] = Field(
|
||||
default_factory=list,
|
||||
description="credentials schema like access_token, refresh_token, etc.",
|
||||
)
|
||||
214
dify/api/core/plugin/entities/parameters.py
Normal file
214
dify/api/core/plugin/entities/parameters.py
Normal file
@@ -0,0 +1,214 @@
|
||||
import json
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.entities.parameter_entities import CommonParameterType
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class PluginParameterOption(BaseModel):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
label: I18nObject = Field(..., description="The label of the option")
|
||||
icon: str | None = Field(default=None, description="The icon of the option, can be a url or a base64 encoded image")
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def transform_id_to_str(cls, value) -> str:
|
||||
if not isinstance(value, str):
|
||||
return str(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
class PluginParameterType(StrEnum):
|
||||
"""
|
||||
all available parameter types
|
||||
"""
|
||||
|
||||
STRING = CommonParameterType.STRING
|
||||
NUMBER = CommonParameterType.NUMBER
|
||||
BOOLEAN = CommonParameterType.BOOLEAN
|
||||
SELECT = CommonParameterType.SELECT
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT
|
||||
FILE = CommonParameterType.FILE
|
||||
FILES = CommonParameterType.FILES
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
|
||||
ANY = CommonParameterType.ANY
|
||||
DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT
|
||||
CHECKBOX = CommonParameterType.CHECKBOX
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES
|
||||
|
||||
# MCP object and array type parameters
|
||||
ARRAY = CommonParameterType.ARRAY
|
||||
OBJECT = CommonParameterType.OBJECT
|
||||
|
||||
|
||||
class MCPServerParameterType(StrEnum):
|
||||
"""
|
||||
MCP server got complex parameter types
|
||||
"""
|
||||
|
||||
ARRAY = auto()
|
||||
OBJECT = auto()
|
||||
|
||||
|
||||
class PluginParameterAutoGenerate(BaseModel):
|
||||
class Type(StrEnum):
|
||||
PROMPT_INSTRUCTION = auto()
|
||||
|
||||
type: Type
|
||||
|
||||
|
||||
class PluginParameterTemplate(BaseModel):
|
||||
enabled: bool = Field(default=False, description="Whether the parameter is jinja enabled")
|
||||
|
||||
|
||||
class PluginParameter(BaseModel):
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
label: I18nObject = Field(..., description="The label presented to the user")
|
||||
placeholder: I18nObject | None = Field(default=None, description="The placeholder presented to the user")
|
||||
scope: str | None = None
|
||||
auto_generate: PluginParameterAutoGenerate | None = None
|
||||
template: PluginParameterTemplate | None = None
|
||||
required: bool = False
|
||||
default: Union[float, int, str, bool] | None = None
|
||||
min: Union[float, int] | None = None
|
||||
max: Union[float, int] | None = None
|
||||
precision: int | None = None
|
||||
options: list[PluginParameterOption] = Field(default_factory=list)
|
||||
|
||||
@field_validator("options", mode="before")
|
||||
@classmethod
|
||||
def transform_options(cls, v):
|
||||
if not isinstance(v, list):
|
||||
return []
|
||||
return v
|
||||
|
||||
|
||||
def as_normal_type(typ: StrEnum):
|
||||
if typ.value in {
|
||||
PluginParameterType.SECRET_INPUT,
|
||||
PluginParameterType.SELECT,
|
||||
PluginParameterType.CHECKBOX,
|
||||
}:
|
||||
return "string"
|
||||
return typ.value
|
||||
|
||||
|
||||
def cast_parameter_value(typ: StrEnum, value: Any, /):
|
||||
try:
|
||||
match typ.value:
|
||||
case (
|
||||
PluginParameterType.STRING
|
||||
| PluginParameterType.SECRET_INPUT
|
||||
| PluginParameterType.SELECT
|
||||
| PluginParameterType.CHECKBOX
|
||||
| PluginParameterType.DYNAMIC_SELECT
|
||||
):
|
||||
if value is None:
|
||||
return ""
|
||||
else:
|
||||
return value if isinstance(value, str) else str(value)
|
||||
|
||||
case PluginParameterType.BOOLEAN:
|
||||
if value is None:
|
||||
return False
|
||||
elif isinstance(value, str):
|
||||
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
|
||||
# and also '0' for False and '1' for True
|
||||
match value.lower():
|
||||
case "true" | "yes" | "y" | "1":
|
||||
return True
|
||||
case "false" | "no" | "n" | "0":
|
||||
return False
|
||||
case _:
|
||||
return bool(value)
|
||||
else:
|
||||
return value if isinstance(value, bool) else bool(value)
|
||||
|
||||
case PluginParameterType.NUMBER:
|
||||
if isinstance(value, int | float):
|
||||
return value
|
||||
elif isinstance(value, str) and value:
|
||||
if "." in value:
|
||||
return float(value)
|
||||
else:
|
||||
return int(value)
|
||||
case PluginParameterType.SYSTEM_FILES | PluginParameterType.FILES:
|
||||
if not isinstance(value, list):
|
||||
return [value]
|
||||
return value
|
||||
case PluginParameterType.FILE:
|
||||
if isinstance(value, list):
|
||||
if len(value) != 1:
|
||||
raise ValueError("This parameter only accepts one file but got multiple files while invoking.")
|
||||
else:
|
||||
return value[0]
|
||||
return value
|
||||
case PluginParameterType.MODEL_SELECTOR | PluginParameterType.APP_SELECTOR:
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError("The selector must be a dictionary.")
|
||||
return value
|
||||
case PluginParameterType.TOOLS_SELECTOR:
|
||||
if value and not isinstance(value, list):
|
||||
raise ValueError("The tools selector must be a list.")
|
||||
return value
|
||||
case PluginParameterType.ANY:
|
||||
if value and not isinstance(value, str | dict | list | int | float):
|
||||
raise ValueError("The var selector must be a string, dictionary, list or number.")
|
||||
return value
|
||||
case PluginParameterType.ARRAY:
|
||||
if not isinstance(value, list):
|
||||
# Try to parse JSON string for arrays
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
parsed_value = json.loads(value)
|
||||
if isinstance(parsed_value, list):
|
||||
return parsed_value
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
return [value]
|
||||
return value
|
||||
case PluginParameterType.OBJECT:
|
||||
if not isinstance(value, dict):
|
||||
# Try to parse JSON string for objects
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
parsed_value = json.loads(value)
|
||||
if isinstance(parsed_value, dict):
|
||||
return parsed_value
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
return {}
|
||||
return value
|
||||
case _:
|
||||
return str(value)
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception:
|
||||
raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.")
|
||||
|
||||
|
||||
def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any):
|
||||
"""
|
||||
init frontend parameter by rule
|
||||
"""
|
||||
parameter_value = value
|
||||
if not parameter_value and parameter_value != 0:
|
||||
# get default value
|
||||
parameter_value = rule.default
|
||||
if not parameter_value and rule.required:
|
||||
raise ValueError(f"tool parameter {rule.name} not found in tool config")
|
||||
|
||||
if type == PluginParameterType.SELECT:
|
||||
# check if tool_parameter_config in options
|
||||
options = [x.value for x in rule.options]
|
||||
if parameter_value is not None and parameter_value not in options:
|
||||
raise ValueError(f"tool parameter {rule.name} value {parameter_value} not in options {options}")
|
||||
|
||||
return cast_parameter_value(type, parameter_value)
|
||||
204
dify/api/core/plugin/entities/plugin.py
Normal file
204
dify/api/core/plugin/entities/plugin.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import datetime
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from packaging.version import InvalidVersion, Version
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from core.agent.plugin_entities import AgentStrategyProviderEntity
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntity
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.plugin.entities.base import BasePluginEntity
|
||||
from core.plugin.entities.endpoint import EndpointProviderDeclaration
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderEntity
|
||||
from core.trigger.entities.entities import TriggerProviderEntity
|
||||
|
||||
|
||||
class PluginInstallationSource(StrEnum):
|
||||
Github = auto()
|
||||
Marketplace = auto()
|
||||
Package = auto()
|
||||
Remote = auto()
|
||||
|
||||
|
||||
class PluginResourceRequirements(BaseModel):
|
||||
memory: int
|
||||
|
||||
class Permission(BaseModel):
|
||||
class Tool(BaseModel):
|
||||
enabled: bool | None = Field(default=False)
|
||||
|
||||
class Model(BaseModel):
|
||||
enabled: bool | None = Field(default=False)
|
||||
llm: bool | None = Field(default=False)
|
||||
text_embedding: bool | None = Field(default=False)
|
||||
rerank: bool | None = Field(default=False)
|
||||
tts: bool | None = Field(default=False)
|
||||
speech2text: bool | None = Field(default=False)
|
||||
moderation: bool | None = Field(default=False)
|
||||
|
||||
class Node(BaseModel):
|
||||
enabled: bool | None = Field(default=False)
|
||||
|
||||
class Endpoint(BaseModel):
|
||||
enabled: bool | None = Field(default=False)
|
||||
|
||||
class Storage(BaseModel):
|
||||
enabled: bool | None = Field(default=False)
|
||||
size: int = Field(ge=1024, le=1073741824, default=1048576)
|
||||
|
||||
tool: Tool | None = Field(default=None)
|
||||
model: Model | None = Field(default=None)
|
||||
node: Node | None = Field(default=None)
|
||||
endpoint: Endpoint | None = Field(default=None)
|
||||
storage: Storage | None = Field(default=None)
|
||||
|
||||
permission: Permission | None = Field(default=None)
|
||||
|
||||
|
||||
class PluginCategory(StrEnum):
|
||||
Tool = auto()
|
||||
Model = auto()
|
||||
Extension = auto()
|
||||
AgentStrategy = "agent-strategy"
|
||||
Datasource = "datasource"
|
||||
Trigger = "trigger"
|
||||
|
||||
|
||||
class PluginDeclaration(BaseModel):
|
||||
class Plugins(BaseModel):
|
||||
tools: list[str] | None = Field(default_factory=list[str])
|
||||
models: list[str] | None = Field(default_factory=list[str])
|
||||
endpoints: list[str] | None = Field(default_factory=list[str])
|
||||
datasources: list[str] | None = Field(default_factory=list[str])
|
||||
triggers: list[str] | None = Field(default_factory=list[str])
|
||||
|
||||
class Meta(BaseModel):
|
||||
minimum_dify_version: str | None = Field(default=None)
|
||||
version: str | None = Field(default=None)
|
||||
|
||||
@field_validator("minimum_dify_version")
|
||||
@classmethod
|
||||
def validate_minimum_dify_version(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return v
|
||||
try:
|
||||
Version(v)
|
||||
return v
|
||||
except InvalidVersion as e:
|
||||
raise ValueError(f"Invalid version format: {v}") from e
|
||||
|
||||
version: str = Field(...)
|
||||
author: str | None = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$")
|
||||
name: str = Field(..., pattern=r"^[a-z0-9_-]{1,128}$")
|
||||
description: I18nObject
|
||||
icon: str
|
||||
icon_dark: str | None = Field(default=None)
|
||||
label: I18nObject
|
||||
category: PluginCategory
|
||||
created_at: datetime.datetime
|
||||
resource: PluginResourceRequirements
|
||||
plugins: Plugins
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
repo: str | None = Field(default=None)
|
||||
verified: bool = Field(default=False)
|
||||
tool: ToolProviderEntity | None = None
|
||||
model: ProviderEntity | None = None
|
||||
endpoint: EndpointProviderDeclaration | None = None
|
||||
agent_strategy: AgentStrategyProviderEntity | None = None
|
||||
datasource: DatasourceProviderEntity | None = None
|
||||
trigger: TriggerProviderEntity | None = None
|
||||
meta: Meta
|
||||
|
||||
@field_validator("version")
|
||||
@classmethod
|
||||
def validate_version(cls, v: str) -> str:
|
||||
try:
|
||||
Version(v)
|
||||
return v
|
||||
except InvalidVersion as e:
|
||||
raise ValueError(f"Invalid version format: {v}") from e
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_category(cls, values: dict):
|
||||
# auto detect category
|
||||
if values.get("tool"):
|
||||
values["category"] = PluginCategory.Tool
|
||||
elif values.get("model"):
|
||||
values["category"] = PluginCategory.Model
|
||||
elif values.get("datasource"):
|
||||
values["category"] = PluginCategory.Datasource
|
||||
elif values.get("agent_strategy"):
|
||||
values["category"] = PluginCategory.AgentStrategy
|
||||
elif values.get("trigger"):
|
||||
values["category"] = PluginCategory.Trigger
|
||||
else:
|
||||
values["category"] = PluginCategory.Extension
|
||||
return values
|
||||
|
||||
|
||||
class PluginInstallation(BasePluginEntity):
|
||||
tenant_id: str
|
||||
endpoints_setups: int
|
||||
endpoints_active: int
|
||||
runtime_type: str
|
||||
source: PluginInstallationSource
|
||||
meta: Mapping[str, Any]
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
version: str
|
||||
checksum: str
|
||||
declaration: PluginDeclaration
|
||||
|
||||
|
||||
class PluginEntity(PluginInstallation):
|
||||
name: str
|
||||
installation_id: str
|
||||
version: str
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_plugin_id(self):
|
||||
if self.declaration.tool:
|
||||
self.declaration.tool.plugin_id = self.plugin_id
|
||||
return self
|
||||
|
||||
|
||||
class PluginDependency(BaseModel):
|
||||
class Type(StrEnum):
|
||||
Github = PluginInstallationSource.Github
|
||||
Marketplace = PluginInstallationSource.Marketplace
|
||||
Package = PluginInstallationSource.Package
|
||||
|
||||
class Github(BaseModel):
|
||||
repo: str
|
||||
version: str
|
||||
package: str
|
||||
github_plugin_unique_identifier: str
|
||||
|
||||
@property
|
||||
def plugin_unique_identifier(self) -> str:
|
||||
return self.github_plugin_unique_identifier
|
||||
|
||||
class Marketplace(BaseModel):
|
||||
marketplace_plugin_unique_identifier: str
|
||||
version: str | None = None
|
||||
|
||||
@property
|
||||
def plugin_unique_identifier(self) -> str:
|
||||
return self.marketplace_plugin_unique_identifier
|
||||
|
||||
class Package(BaseModel):
|
||||
plugin_unique_identifier: str
|
||||
version: str | None = None
|
||||
|
||||
type: Type
|
||||
value: Github | Marketplace | Package
|
||||
current_identifier: str | None = None
|
||||
|
||||
|
||||
class MissingPluginDependency(BaseModel):
|
||||
plugin_unique_identifier: str
|
||||
current_identifier: str | None = None
|
||||
259
dify/api/core/plugin/entities/plugin_daemon.py
Normal file
259
dify/api/core/plugin/entities/plugin_daemon.py
Normal file
@@ -0,0 +1,259 @@
|
||||
import enum
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.agent.plugin_entities import AgentProviderEntityWithPlugin
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.plugin.entities.base import BasePluginEntity
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
|
||||
from core.trigger.entities.entities import TriggerProviderEntity
|
||||
|
||||
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
|
||||
|
||||
|
||||
class PluginDaemonBasicResponse(BaseModel, Generic[T]):
|
||||
"""
|
||||
Basic response from plugin daemon.
|
||||
"""
|
||||
|
||||
code: int
|
||||
message: str
|
||||
data: T | None = None
|
||||
|
||||
|
||||
class InstallPluginMessage(BaseModel):
|
||||
"""
|
||||
Message for installing a plugin.
|
||||
"""
|
||||
|
||||
class Event(StrEnum):
|
||||
Info = "info"
|
||||
Done = "done"
|
||||
Error = "error"
|
||||
|
||||
event: Event
|
||||
data: str
|
||||
|
||||
|
||||
class PluginToolProviderEntity(BaseModel):
|
||||
provider: str
|
||||
plugin_unique_identifier: str
|
||||
plugin_id: str
|
||||
declaration: ToolProviderEntityWithPlugin
|
||||
|
||||
|
||||
class PluginDatasourceProviderEntity(BaseModel):
|
||||
provider: str
|
||||
plugin_unique_identifier: str
|
||||
plugin_id: str
|
||||
is_authorized: bool = False
|
||||
declaration: DatasourceProviderEntityWithPlugin
|
||||
|
||||
|
||||
class PluginAgentProviderEntity(BaseModel):
|
||||
provider: str
|
||||
plugin_unique_identifier: str
|
||||
plugin_id: str
|
||||
declaration: AgentProviderEntityWithPlugin
|
||||
meta: PluginDeclaration.Meta
|
||||
|
||||
|
||||
class PluginBasicBooleanResponse(BaseModel):
|
||||
"""
|
||||
Basic boolean response from plugin daemon.
|
||||
"""
|
||||
|
||||
result: bool
|
||||
credentials: dict | None = None
|
||||
|
||||
|
||||
class PluginModelSchemaEntity(BaseModel):
|
||||
model_schema: AIModelEntity = Field(description="The model schema.")
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class PluginModelProviderEntity(BaseModel):
|
||||
id: str = Field(description="ID")
|
||||
created_at: datetime = Field(description="The created at time of the model provider.")
|
||||
updated_at: datetime = Field(description="The updated at time of the model provider.")
|
||||
provider: str = Field(description="The provider of the model.")
|
||||
tenant_id: str = Field(description="The tenant ID.")
|
||||
plugin_unique_identifier: str = Field(description="The plugin unique identifier.")
|
||||
plugin_id: str = Field(description="The plugin ID.")
|
||||
declaration: ProviderEntity = Field(description="The declaration of the model provider.")
|
||||
|
||||
|
||||
class PluginTextEmbeddingNumTokensResponse(BaseModel):
|
||||
"""
|
||||
Response for number of tokens.
|
||||
"""
|
||||
|
||||
num_tokens: list[int] = Field(description="The number of tokens.")
|
||||
|
||||
|
||||
class PluginLLMNumTokensResponse(BaseModel):
|
||||
"""
|
||||
Response for number of tokens.
|
||||
"""
|
||||
|
||||
num_tokens: int = Field(description="The number of tokens.")
|
||||
|
||||
|
||||
class PluginStringResultResponse(BaseModel):
|
||||
result: str = Field(description="The result of the string.")
|
||||
|
||||
|
||||
class PluginVoiceEntity(BaseModel):
|
||||
name: str = Field(description="The name of the voice.")
|
||||
value: str = Field(description="The value of the voice.")
|
||||
|
||||
|
||||
class PluginVoicesResponse(BaseModel):
|
||||
voices: list[PluginVoiceEntity] = Field(description="The result of the voices.")
|
||||
|
||||
|
||||
class PluginDaemonError(BaseModel):
|
||||
"""
|
||||
Error from plugin daemon.
|
||||
"""
|
||||
|
||||
error_type: str
|
||||
message: str
|
||||
|
||||
|
||||
class PluginDaemonInnerError(Exception):
|
||||
code: int
|
||||
message: str
|
||||
|
||||
def __init__(self, code: int, message: str):
|
||||
self.code = code
|
||||
self.message = message
|
||||
|
||||
|
||||
class PluginInstallTaskStatus(StrEnum):
|
||||
Pending = "pending"
|
||||
Running = "running"
|
||||
Success = "success"
|
||||
Failed = "failed"
|
||||
|
||||
|
||||
class PluginInstallTaskPluginStatus(BaseModel):
|
||||
plugin_unique_identifier: str = Field(description="The plugin unique identifier of the install task.")
|
||||
plugin_id: str = Field(description="The plugin ID of the install task.")
|
||||
status: PluginInstallTaskStatus = Field(description="The status of the install task.")
|
||||
message: str = Field(description="The message of the install task.")
|
||||
icon: str = Field(description="The icon of the plugin.")
|
||||
labels: I18nObject = Field(description="The labels of the plugin.")
|
||||
|
||||
|
||||
class PluginInstallTask(BasePluginEntity):
|
||||
status: PluginInstallTaskStatus = Field(description="The status of the install task.")
|
||||
total_plugins: int = Field(description="The total number of plugins to be installed.")
|
||||
completed_plugins: int = Field(description="The number of plugins that have been installed.")
|
||||
plugins: list[PluginInstallTaskPluginStatus] = Field(description="The status of the plugins.")
|
||||
|
||||
|
||||
class PluginInstallTaskStartResponse(BaseModel):
|
||||
all_installed: bool = Field(description="Whether all plugins are installed.")
|
||||
task_id: str = Field(description="The ID of the install task.")
|
||||
|
||||
|
||||
class PluginVerification(BaseModel):
|
||||
"""
|
||||
Verification of the plugin.
|
||||
"""
|
||||
|
||||
class AuthorizedCategory(StrEnum):
|
||||
Langgenius = "langgenius"
|
||||
Partner = "partner"
|
||||
Community = "community"
|
||||
|
||||
authorized_category: AuthorizedCategory = Field(description="The authorized category of the plugin.")
|
||||
|
||||
|
||||
class PluginDecodeResponse(BaseModel):
|
||||
unique_identifier: str = Field(description="The unique identifier of the plugin.")
|
||||
manifest: PluginDeclaration
|
||||
verification: PluginVerification | None = Field(default=None, description="Basic verification information")
|
||||
|
||||
|
||||
class PluginOAuthAuthorizationUrlResponse(BaseModel):
|
||||
authorization_url: str = Field(description="The URL of the authorization.")
|
||||
|
||||
|
||||
class PluginOAuthCredentialsResponse(BaseModel):
|
||||
metadata: Mapping[str, Any] = Field(
|
||||
default_factory=dict, description="The metadata of the OAuth, like avatar url, name, etc."
|
||||
)
|
||||
expires_at: int = Field(default=-1, description="The expires at time of the credentials. UTC timestamp.")
|
||||
credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.")
|
||||
|
||||
|
||||
class PluginListResponse(BaseModel):
|
||||
list: list[PluginEntity]
|
||||
total: int
|
||||
|
||||
|
||||
class PluginDynamicSelectOptionsResponse(BaseModel):
|
||||
options: Sequence[PluginParameterOption] = Field(description="The options of the dynamic select.")
|
||||
|
||||
|
||||
class PluginTriggerProviderEntity(BaseModel):
|
||||
provider: str
|
||||
plugin_unique_identifier: str
|
||||
plugin_id: str
|
||||
declaration: TriggerProviderEntity
|
||||
|
||||
|
||||
class CredentialType(enum.StrEnum):
|
||||
API_KEY = "api-key"
|
||||
OAUTH2 = "oauth2"
|
||||
UNAUTHORIZED = "unauthorized"
|
||||
|
||||
def get_name(self):
|
||||
if self == CredentialType.API_KEY:
|
||||
return "API KEY"
|
||||
elif self == CredentialType.OAUTH2:
|
||||
return "AUTH"
|
||||
elif self == CredentialType.UNAUTHORIZED:
|
||||
return "UNAUTHORIZED"
|
||||
else:
|
||||
return self.value.replace("-", " ").upper()
|
||||
|
||||
def is_editable(self):
|
||||
return self == CredentialType.API_KEY
|
||||
|
||||
def is_validate_allowed(self):
|
||||
return self == CredentialType.API_KEY
|
||||
|
||||
@classmethod
|
||||
def values(cls):
|
||||
return [item.value for item in cls]
|
||||
|
||||
@classmethod
|
||||
def of(cls, credential_type: str) -> "CredentialType":
|
||||
type_name = credential_type.lower()
|
||||
if type_name in {"api-key", "api_key"}:
|
||||
return cls.API_KEY
|
||||
elif type_name in {"oauth2", "oauth"}:
|
||||
return cls.OAUTH2
|
||||
elif type_name == "unauthorized":
|
||||
return cls.UNAUTHORIZED
|
||||
else:
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
|
||||
class PluginReadmeResponse(BaseModel):
|
||||
content: str = Field(description="The readme of the plugin.")
|
||||
language: str = Field(description="The language of the readme.")
|
||||
284
dify/api/core/plugin/entities/request.py
Normal file
284
dify/api/core/plugin/entities/request.py
Normal file
@@ -0,0 +1,284 @@
|
||||
import binascii
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import Response
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageRole,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.utils.http_parser import deserialize_response
|
||||
from core.workflow.nodes.parameter_extractor.entities import (
|
||||
ModelConfig as ParameterExtractorModelConfig,
|
||||
)
|
||||
from core.workflow.nodes.parameter_extractor.entities import (
|
||||
ParameterConfig,
|
||||
)
|
||||
from core.workflow.nodes.question_classifier.entities import (
|
||||
ClassConfig,
|
||||
)
|
||||
from core.workflow.nodes.question_classifier.entities import (
|
||||
ModelConfig as QuestionClassifierModelConfig,
|
||||
)
|
||||
|
||||
|
||||
class InvokeCredentials(BaseModel):
|
||||
tool_credentials: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Map of tool provider to credential id, used to store the credential id for the tool provider.",
|
||||
)
|
||||
|
||||
|
||||
class PluginInvokeContext(BaseModel):
|
||||
credentials: InvokeCredentials | None = Field(
|
||||
default_factory=InvokeCredentials,
|
||||
description="Credentials context for the plugin invocation or backward invocation.",
|
||||
)
|
||||
|
||||
|
||||
class RequestInvokeTool(BaseModel):
|
||||
"""
|
||||
Request to invoke a tool
|
||||
"""
|
||||
|
||||
tool_type: Literal["builtin", "workflow", "api", "mcp"]
|
||||
provider: str
|
||||
tool: str
|
||||
tool_parameters: dict
|
||||
credential_id: str | None = None
|
||||
|
||||
|
||||
class BaseRequestInvokeModel(BaseModel):
|
||||
provider: str
|
||||
model: str
|
||||
model_type: ModelType
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class RequestInvokeLLM(BaseRequestInvokeModel):
|
||||
"""
|
||||
Request to invoke LLM
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.LLM
|
||||
mode: str
|
||||
completion_params: dict[str, Any] = Field(default_factory=dict)
|
||||
prompt_messages: list[PromptMessage] = Field(default_factory=list)
|
||||
tools: list[PromptMessageTool] | None = Field(default_factory=list[PromptMessageTool])
|
||||
stop: list[str] | None = Field(default_factory=list[str])
|
||||
stream: bool | None = False
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@field_validator("prompt_messages", mode="before")
|
||||
@classmethod
|
||||
def convert_prompt_messages(cls, v):
|
||||
if not isinstance(v, list):
|
||||
raise ValueError("prompt_messages must be a list")
|
||||
|
||||
for i in range(len(v)):
|
||||
if v[i]["role"] == PromptMessageRole.USER:
|
||||
v[i] = UserPromptMessage.model_validate(v[i])
|
||||
elif v[i]["role"] == PromptMessageRole.ASSISTANT:
|
||||
v[i] = AssistantPromptMessage.model_validate(v[i])
|
||||
elif v[i]["role"] == PromptMessageRole.SYSTEM:
|
||||
v[i] = SystemPromptMessage.model_validate(v[i])
|
||||
elif v[i]["role"] == PromptMessageRole.TOOL:
|
||||
v[i] = ToolPromptMessage.model_validate(v[i])
|
||||
else:
|
||||
v[i] = PromptMessage.model_validate(v[i])
|
||||
|
||||
return v
|
||||
|
||||
|
||||
class RequestInvokeLLMWithStructuredOutput(RequestInvokeLLM):
|
||||
"""
|
||||
Request to invoke LLM with structured output
|
||||
"""
|
||||
|
||||
structured_output_schema: dict[str, Any] = Field(
|
||||
default_factory=dict, description="The schema of the structured output in JSON schema format"
|
||||
)
|
||||
|
||||
|
||||
class RequestInvokeTextEmbedding(BaseRequestInvokeModel):
|
||||
"""
|
||||
Request to invoke text embedding
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.TEXT_EMBEDDING
|
||||
texts: list[str]
|
||||
|
||||
|
||||
class RequestInvokeRerank(BaseRequestInvokeModel):
|
||||
"""
|
||||
Request to invoke rerank
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.RERANK
|
||||
query: str
|
||||
docs: list[str]
|
||||
score_threshold: float
|
||||
top_n: int
|
||||
|
||||
|
||||
class RequestInvokeTTS(BaseRequestInvokeModel):
|
||||
"""
|
||||
Request to invoke TTS
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.TTS
|
||||
content_text: str
|
||||
voice: str
|
||||
|
||||
|
||||
class RequestInvokeSpeech2Text(BaseRequestInvokeModel):
|
||||
"""
|
||||
Request to invoke speech2text
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.SPEECH2TEXT
|
||||
file: bytes
|
||||
|
||||
@field_validator("file", mode="before")
|
||||
@classmethod
|
||||
def convert_file(cls, v):
|
||||
# hex string to bytes
|
||||
if isinstance(v, str):
|
||||
return bytes.fromhex(v)
|
||||
else:
|
||||
raise ValueError("file must be a hex string")
|
||||
|
||||
|
||||
class RequestInvokeModeration(BaseRequestInvokeModel):
|
||||
"""
|
||||
Request to invoke moderation
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.MODERATION
|
||||
text: str
|
||||
|
||||
|
||||
class RequestInvokeParameterExtractorNode(BaseModel):
|
||||
"""
|
||||
Request to invoke parameter extractor node
|
||||
"""
|
||||
|
||||
parameters: list[ParameterConfig]
|
||||
model: ParameterExtractorModelConfig
|
||||
instruction: str
|
||||
query: str
|
||||
|
||||
|
||||
class RequestInvokeQuestionClassifierNode(BaseModel):
|
||||
"""
|
||||
Request to invoke question classifier node
|
||||
"""
|
||||
|
||||
query: str
|
||||
model: QuestionClassifierModelConfig
|
||||
classes: list[ClassConfig]
|
||||
instruction: str
|
||||
|
||||
|
||||
class RequestInvokeApp(BaseModel):
|
||||
"""
|
||||
Request to invoke app
|
||||
"""
|
||||
|
||||
app_id: str
|
||||
inputs: dict[str, Any]
|
||||
query: str | None = None
|
||||
response_mode: Literal["blocking", "streaming"]
|
||||
conversation_id: str | None = None
|
||||
user: str | None = None
|
||||
files: list[dict] = Field(default_factory=list)
|
||||
|
||||
|
||||
class RequestInvokeEncrypt(BaseModel):
|
||||
"""
|
||||
Request to encryption
|
||||
"""
|
||||
|
||||
opt: Literal["encrypt", "decrypt", "clear"]
|
||||
namespace: Literal["endpoint"]
|
||||
identity: str
|
||||
data: dict = Field(default_factory=dict)
|
||||
config: list[BasicProviderConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
class RequestInvokeSummary(BaseModel):
|
||||
"""
|
||||
Request to summary
|
||||
"""
|
||||
|
||||
text: str
|
||||
instruction: str
|
||||
|
||||
|
||||
class RequestRequestUploadFile(BaseModel):
|
||||
"""
|
||||
Request to upload file
|
||||
"""
|
||||
|
||||
filename: str
|
||||
mimetype: str
|
||||
|
||||
|
||||
class RequestFetchAppInfo(BaseModel):
|
||||
"""
|
||||
Request to fetch app info
|
||||
"""
|
||||
|
||||
app_id: str
|
||||
|
||||
|
||||
class TriggerInvokeEventResponse(BaseModel):
|
||||
variables: Mapping[str, Any] = Field(default_factory=dict)
|
||||
cancelled: bool = Field(default=False)
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True)
|
||||
|
||||
@field_validator("variables", mode="before")
|
||||
@classmethod
|
||||
def convert_variables(cls, v):
|
||||
if isinstance(v, str):
|
||||
return json.loads(v)
|
||||
else:
|
||||
return v
|
||||
|
||||
|
||||
class TriggerSubscriptionResponse(BaseModel):
|
||||
subscription: dict[str, Any]
|
||||
|
||||
|
||||
class TriggerValidateProviderCredentialsResponse(BaseModel):
|
||||
result: bool
|
||||
|
||||
|
||||
class TriggerDispatchResponse(BaseModel):
|
||||
user_id: str
|
||||
events: list[str]
|
||||
response: Response
|
||||
payload: Mapping[str, Any] = Field(default_factory=dict)
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True)
|
||||
|
||||
@field_validator("response", mode="before")
|
||||
@classmethod
|
||||
def convert_response(cls, v: str):
|
||||
try:
|
||||
return deserialize_response(binascii.unhexlify(v.encode()))
|
||||
except Exception as e:
|
||||
raise ValueError("Failed to deserialize response from hex string") from e
|
||||
117
dify/api/core/plugin/impl/agent.py
Normal file
117
dify/api/core/plugin/impl/agent.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginAgentProviderEntity,
|
||||
)
|
||||
from core.plugin.entities.request import PluginInvokeContext
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.plugin.utils.chunk_merger import merge_blob_chunks
|
||||
from models.provider_ids import GenericProviderID
|
||||
|
||||
|
||||
class PluginAgentClient(BasePluginClient):
|
||||
def fetch_agent_strategy_providers(self, tenant_id: str) -> list[PluginAgentProviderEntity]:
|
||||
"""
|
||||
Fetch agent providers for the given tenant.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]):
|
||||
for provider in json_response.get("data", []):
|
||||
declaration = provider.get("declaration", {}) or {}
|
||||
provider_name = declaration.get("identity", {}).get("name")
|
||||
for strategy in declaration.get("strategies", []):
|
||||
strategy["identity"]["provider"] = provider_name
|
||||
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/agent_strategies",
|
||||
list[PluginAgentProviderEntity],
|
||||
params={"page": 1, "page_size": 256},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
for provider in response:
|
||||
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for strategy in provider.declaration.strategies:
|
||||
strategy.identity.provider = provider.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
def fetch_agent_strategy_provider(self, tenant_id: str, provider: str) -> PluginAgentProviderEntity:
|
||||
"""
|
||||
Fetch tool provider for the given tenant and plugin.
|
||||
"""
|
||||
agent_provider_id = GenericProviderID(provider)
|
||||
|
||||
def transformer(json_response: dict[str, Any]):
|
||||
# skip if error occurs
|
||||
if json_response.get("data") is None or json_response.get("data", {}).get("declaration") is None:
|
||||
return json_response
|
||||
|
||||
for strategy in json_response.get("data", {}).get("declaration", {}).get("strategies", []):
|
||||
strategy["identity"]["provider"] = agent_provider_id.provider_name
|
||||
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/agent_strategy",
|
||||
PluginAgentProviderEntity,
|
||||
params={"provider": agent_provider_id.provider_name, "plugin_id": agent_provider_id.plugin_id},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for strategy in response.declaration.strategies:
|
||||
strategy.identity.provider = response.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
agent_provider: str,
|
||||
agent_strategy: str,
|
||||
agent_params: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
context: PluginInvokeContext | None = None,
|
||||
) -> Generator[AgentInvokeMessage, None, None]:
|
||||
"""
|
||||
Invoke the agent with the given tenant, user, plugin, provider, name and parameters.
|
||||
"""
|
||||
|
||||
agent_provider_id = GenericProviderID(agent_provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/agent_strategy/invoke",
|
||||
AgentInvokeMessage,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"conversation_id": conversation_id,
|
||||
"app_id": app_id,
|
||||
"message_id": message_id,
|
||||
"context": context.model_dump() if context else {},
|
||||
"data": {
|
||||
"agent_strategy_provider": agent_provider_id.provider_name,
|
||||
"agent_strategy": agent_strategy,
|
||||
"agent_strategy_params": agent_params,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": agent_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
return merge_blob_chunks(response)
|
||||
22
dify/api/core/plugin/impl/asset.py
Normal file
22
dify/api/core/plugin/impl/asset.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
|
||||
|
||||
class PluginAssetManager(BasePluginClient):
|
||||
def fetch_asset(self, tenant_id: str, id: str) -> bytes:
|
||||
"""
|
||||
Fetch an asset by id.
|
||||
"""
|
||||
response = self._request(method="GET", path=f"plugin/{tenant_id}/asset/{id}")
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"can not found asset {id}")
|
||||
return response.content
|
||||
|
||||
def extract_asset(self, tenant_id: str, plugin_unique_identifier: str, filename: str) -> bytes:
|
||||
response = self._request(
|
||||
method="GET",
|
||||
path=f"plugin/{tenant_id}/extract-asset/",
|
||||
params={"plugin_unique_identifier": plugin_unique_identifier, "file_path": filename},
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"can not found asset {plugin_unique_identifier}, {str(response.status_code)}")
|
||||
return response.content
|
||||
336
dify/api/core/plugin/impl/base.py
Normal file
336
dify/api/core/plugin/impl/base.py
Normal file
@@ -0,0 +1,336 @@
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.plugin.endpoint.exc import EndpointSetupFailedError
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError
|
||||
from core.plugin.impl.exc import (
|
||||
PluginDaemonBadRequestError,
|
||||
PluginDaemonInternalServerError,
|
||||
PluginDaemonNotFoundError,
|
||||
PluginDaemonUnauthorizedError,
|
||||
PluginInvokeError,
|
||||
PluginNotFoundError,
|
||||
PluginPermissionDeniedError,
|
||||
PluginUniqueIdentifierError,
|
||||
)
|
||||
from core.trigger.errors import (
|
||||
EventIgnoreError,
|
||||
TriggerInvokeError,
|
||||
TriggerPluginInvokeError,
|
||||
TriggerProviderCredentialValidationError,
|
||||
)
|
||||
|
||||
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
|
||||
_plugin_daemon_timeout_config = cast(
|
||||
float | httpx.Timeout | None,
|
||||
getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 300.0),
|
||||
)
|
||||
plugin_daemon_request_timeout: httpx.Timeout | None
|
||||
if _plugin_daemon_timeout_config is None:
|
||||
plugin_daemon_request_timeout = None
|
||||
elif isinstance(_plugin_daemon_timeout_config, httpx.Timeout):
|
||||
plugin_daemon_request_timeout = _plugin_daemon_timeout_config
|
||||
else:
|
||||
plugin_daemon_request_timeout = httpx.Timeout(_plugin_daemon_timeout_config)
|
||||
|
||||
T = TypeVar("T", bound=(BaseModel | dict[str, Any] | list[Any] | bool | str))
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BasePluginClient:
|
||||
def _request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
data: bytes | dict[str, Any] | str | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
files: dict[str, Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Make a request to the plugin daemon inner API.
|
||||
"""
|
||||
url, headers, prepared_data, params, files = self._prepare_request(path, headers, data, params, files)
|
||||
|
||||
request_kwargs: dict[str, Any] = {
|
||||
"method": method,
|
||||
"url": url,
|
||||
"headers": headers,
|
||||
"params": params,
|
||||
"files": files,
|
||||
"timeout": plugin_daemon_request_timeout,
|
||||
}
|
||||
if isinstance(prepared_data, dict):
|
||||
request_kwargs["data"] = prepared_data
|
||||
elif prepared_data is not None:
|
||||
request_kwargs["content"] = prepared_data
|
||||
|
||||
try:
|
||||
response = httpx.request(**request_kwargs)
|
||||
except httpx.RequestError:
|
||||
logger.exception("Request to Plugin Daemon Service failed")
|
||||
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
|
||||
|
||||
return response
|
||||
|
||||
def _prepare_request(
|
||||
self,
|
||||
path: str,
|
||||
headers: dict[str, str] | None,
|
||||
data: bytes | dict[str, Any] | str | None,
|
||||
params: dict[str, Any] | None,
|
||||
files: dict[str, Any] | None,
|
||||
) -> tuple[str, dict[str, str], bytes | dict[str, Any] | str | None, dict[str, Any] | None, dict[str, Any] | None]:
|
||||
url = plugin_daemon_inner_api_baseurl / path
|
||||
prepared_headers = dict(headers or {})
|
||||
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
|
||||
prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")
|
||||
|
||||
prepared_data: bytes | dict[str, Any] | str | None = (
|
||||
data if isinstance(data, (bytes, str, dict)) or data is None else None
|
||||
)
|
||||
if isinstance(data, dict):
|
||||
if prepared_headers.get("Content-Type") == "application/json":
|
||||
prepared_data = json.dumps(data)
|
||||
else:
|
||||
prepared_data = data
|
||||
|
||||
return str(url), prepared_headers, prepared_data, params, files
|
||||
|
||||
def _stream_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
data: bytes | dict[str, Any] | None = None,
|
||||
files: dict[str, Any] | None = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Make a stream request to the plugin daemon inner API
|
||||
"""
|
||||
url, headers, prepared_data, params, files = self._prepare_request(path, headers, data, params, files)
|
||||
|
||||
stream_kwargs: dict[str, Any] = {
|
||||
"method": method,
|
||||
"url": url,
|
||||
"headers": headers,
|
||||
"params": params,
|
||||
"files": files,
|
||||
"timeout": plugin_daemon_request_timeout,
|
||||
}
|
||||
if isinstance(prepared_data, dict):
|
||||
stream_kwargs["data"] = prepared_data
|
||||
elif prepared_data is not None:
|
||||
stream_kwargs["content"] = prepared_data
|
||||
|
||||
try:
|
||||
with httpx.stream(**stream_kwargs) as response:
|
||||
for raw_line in response.iter_lines():
|
||||
if not raw_line:
|
||||
continue
|
||||
line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
|
||||
line = line.strip()
|
||||
if line.startswith("data:"):
|
||||
line = line[5:].strip()
|
||||
if line:
|
||||
yield line
|
||||
except httpx.RequestError:
|
||||
logger.exception("Stream request to Plugin Daemon Service failed")
|
||||
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
|
||||
|
||||
def _stream_request_with_model(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
type_: type[T],
|
||||
headers: dict[str, str] | None = None,
|
||||
data: bytes | dict[str, Any] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
files: dict[str, Any] | None = None,
|
||||
) -> Generator[T, None, None]:
|
||||
"""
|
||||
Make a stream request to the plugin daemon inner API and yield the response as a model.
|
||||
"""
|
||||
for line in self._stream_request(method, path, params, headers, data, files):
|
||||
yield type_(**json.loads(line)) # type: ignore
|
||||
|
||||
def _request_with_model(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
type_: type[T],
|
||||
headers: dict[str, str] | None = None,
|
||||
data: bytes | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
files: dict[str, Any] | None = None,
|
||||
) -> T:
|
||||
"""
|
||||
Make a request to the plugin daemon inner API and return the response as a model.
|
||||
"""
|
||||
response = self._request(method, path, headers, data, params, files)
|
||||
return type_(**response.json()) # type: ignore[return-value]
|
||||
|
||||
def _request_with_plugin_daemon_response(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
type_: type[T],
|
||||
headers: dict[str, str] | None = None,
|
||||
data: bytes | dict[str, Any] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
files: dict[str, Any] | None = None,
|
||||
transformer: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
) -> T:
|
||||
"""
|
||||
Make a request to the plugin daemon inner API and return the response as a model.
|
||||
"""
|
||||
try:
|
||||
response = self._request(method, path, headers, data, params, files)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.exception("Failed to request plugin daemon, status: %s, url: %s", e.response.status_code, path)
|
||||
raise e
|
||||
except Exception as e:
|
||||
msg = f"Failed to request plugin daemon, url: {path}"
|
||||
logger.exception("Failed to request plugin daemon, url: %s", path)
|
||||
raise ValueError(msg) from e
|
||||
|
||||
try:
|
||||
json_response = response.json()
|
||||
if transformer:
|
||||
json_response = transformer(json_response)
|
||||
# https://stackoverflow.com/questions/59634937/variable-foo-class-is-not-valid-as-type-but-why
|
||||
rep = PluginDaemonBasicResponse[type_].model_validate(json_response) # type: ignore
|
||||
except Exception:
|
||||
msg = (
|
||||
f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type_.__name__)}],"
|
||||
f" url: {path}"
|
||||
)
|
||||
logger.exception(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
if rep.code != 0:
|
||||
try:
|
||||
error = PluginDaemonError.model_validate(json.loads(rep.message))
|
||||
except Exception:
|
||||
raise ValueError(f"{rep.message}, code: {rep.code}")
|
||||
|
||||
self._handle_plugin_daemon_error(error.error_type, error.message)
|
||||
if rep.data is None:
|
||||
frame = inspect.currentframe()
|
||||
raise ValueError(f"got empty data from plugin daemon: {frame.f_lineno if frame else 'unknown'}")
|
||||
|
||||
return rep.data
|
||||
|
||||
def _request_with_plugin_daemon_response_stream(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
type_: type[T],
|
||||
headers: dict[str, str] | None = None,
|
||||
data: bytes | dict[str, Any] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
files: dict[str, Any] | None = None,
|
||||
) -> Generator[T, None, None]:
|
||||
"""
|
||||
Make a stream request to the plugin daemon inner API and yield the response as a model.
|
||||
"""
|
||||
for line in self._stream_request(method, path, params, headers, data, files):
|
||||
try:
|
||||
rep = PluginDaemonBasicResponse[type_].model_validate_json(line) # type: ignore
|
||||
except (ValueError, TypeError):
|
||||
# TODO modify this when line_data has code and message
|
||||
try:
|
||||
line_data = json.loads(line)
|
||||
except (ValueError, TypeError):
|
||||
raise ValueError(line)
|
||||
# If the dictionary contains the `error` key, use its value as the argument
|
||||
# for `ValueError`.
|
||||
# Otherwise, use the `line` to provide better contextual information about the error.
|
||||
raise ValueError(line_data.get("error", line))
|
||||
|
||||
if rep.code != 0:
|
||||
if rep.code == -500:
|
||||
try:
|
||||
error = PluginDaemonError.model_validate(json.loads(rep.message))
|
||||
except Exception:
|
||||
raise PluginDaemonInnerError(code=rep.code, message=rep.message)
|
||||
|
||||
logger.error("Error in stream response for plugin %s", rep.__dict__)
|
||||
self._handle_plugin_daemon_error(error.error_type, error.message)
|
||||
raise ValueError(f"plugin daemon: {rep.message}, code: {rep.code}")
|
||||
if rep.data is None:
|
||||
frame = inspect.currentframe()
|
||||
raise ValueError(f"got empty data from plugin daemon: {frame.f_lineno if frame else 'unknown'}")
|
||||
yield rep.data
|
||||
|
||||
def _handle_plugin_daemon_error(self, error_type: str, message: str):
|
||||
"""
|
||||
handle the error from plugin daemon
|
||||
"""
|
||||
match error_type:
|
||||
case PluginDaemonInnerError.__name__:
|
||||
raise PluginDaemonInnerError(code=-500, message=message)
|
||||
case PluginInvokeError.__name__:
|
||||
error_object = json.loads(message)
|
||||
invoke_error_type = error_object.get("error_type")
|
||||
args = error_object.get("args")
|
||||
match invoke_error_type:
|
||||
case InvokeRateLimitError.__name__:
|
||||
raise InvokeRateLimitError(description=args.get("description"))
|
||||
case InvokeAuthorizationError.__name__:
|
||||
raise InvokeAuthorizationError(description=args.get("description"))
|
||||
case InvokeBadRequestError.__name__:
|
||||
raise InvokeBadRequestError(description=args.get("description"))
|
||||
case InvokeConnectionError.__name__:
|
||||
raise InvokeConnectionError(description=args.get("description"))
|
||||
case InvokeServerUnavailableError.__name__:
|
||||
raise InvokeServerUnavailableError(description=args.get("description"))
|
||||
case CredentialsValidateFailedError.__name__:
|
||||
raise CredentialsValidateFailedError(error_object.get("message"))
|
||||
case EndpointSetupFailedError.__name__:
|
||||
raise EndpointSetupFailedError(error_object.get("message"))
|
||||
case TriggerProviderCredentialValidationError.__name__:
|
||||
raise TriggerProviderCredentialValidationError(error_object.get("message"))
|
||||
case TriggerPluginInvokeError.__name__:
|
||||
raise TriggerPluginInvokeError(description=error_object.get("description"))
|
||||
case TriggerInvokeError.__name__:
|
||||
raise TriggerInvokeError(error_object.get("message"))
|
||||
case EventIgnoreError.__name__:
|
||||
raise EventIgnoreError(description=error_object.get("description"))
|
||||
case _:
|
||||
raise PluginInvokeError(description=message)
|
||||
case PluginDaemonInternalServerError.__name__:
|
||||
raise PluginDaemonInternalServerError(description=message)
|
||||
case PluginDaemonBadRequestError.__name__:
|
||||
raise PluginDaemonBadRequestError(description=message)
|
||||
case PluginDaemonNotFoundError.__name__:
|
||||
raise PluginDaemonNotFoundError(description=message)
|
||||
case PluginUniqueIdentifierError.__name__:
|
||||
raise PluginUniqueIdentifierError(description=message)
|
||||
case PluginNotFoundError.__name__:
|
||||
raise PluginNotFoundError(description=message)
|
||||
case PluginDaemonUnauthorizedError.__name__:
|
||||
raise PluginDaemonUnauthorizedError(description=message)
|
||||
case PluginPermissionDeniedError.__name__:
|
||||
raise PluginPermissionDeniedError(description=message)
|
||||
case _:
|
||||
raise Exception(f"got unknown error from plugin daemon: {error_type}, message: {message}")
|
||||
374
dify/api/core/plugin/impl/datasource.py
Normal file
374
dify/api/core/plugin/impl/datasource.py
Normal file
@@ -0,0 +1,374 @@
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceMessage,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
OnlineDocumentPagesMessage,
|
||||
OnlineDriveBrowseFilesRequest,
|
||||
OnlineDriveBrowseFilesResponse,
|
||||
OnlineDriveDownloadFileRequest,
|
||||
WebsiteCrawlMessage,
|
||||
)
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginBasicBooleanResponse,
|
||||
PluginDatasourceProviderEntity,
|
||||
)
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.schemas.resolver import resolve_dify_schema_refs
|
||||
from models.provider_ids import DatasourceProviderID, GenericProviderID
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
class PluginDatasourceManager(BasePluginClient):
|
||||
def fetch_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
|
||||
"""
|
||||
Fetch datasource providers for the given tenant.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
if json_response.get("data"):
|
||||
for provider in json_response.get("data", []):
|
||||
declaration = provider.get("declaration", {}) or {}
|
||||
provider_name = declaration.get("identity", {}).get("name")
|
||||
for datasource in declaration.get("datasources", []):
|
||||
datasource["identity"]["provider"] = provider_name
|
||||
# resolve refs
|
||||
if datasource.get("output_schema"):
|
||||
datasource["output_schema"] = resolve_dify_schema_refs(datasource["output_schema"])
|
||||
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/datasources",
|
||||
list[PluginDatasourceProviderEntity],
|
||||
params={"page": 1, "page_size": 256},
|
||||
transformer=transformer,
|
||||
)
|
||||
local_file_datasource_provider = PluginDatasourceProviderEntity.model_validate(
|
||||
self._get_local_file_datasource_provider()
|
||||
)
|
||||
|
||||
for provider in response:
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
|
||||
all_response = [local_file_datasource_provider] + response
|
||||
|
||||
for provider in all_response:
|
||||
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for tool in provider.declaration.datasources:
|
||||
tool.identity.provider = provider.declaration.identity.name
|
||||
|
||||
return all_response
|
||||
|
||||
def fetch_installed_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
|
||||
"""
|
||||
Fetch datasource providers for the given tenant.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
if json_response.get("data"):
|
||||
for provider in json_response.get("data", []):
|
||||
declaration = provider.get("declaration", {}) or {}
|
||||
provider_name = declaration.get("identity", {}).get("name")
|
||||
for datasource in declaration.get("datasources", []):
|
||||
datasource["identity"]["provider"] = provider_name
|
||||
# resolve refs
|
||||
if datasource.get("output_schema"):
|
||||
datasource["output_schema"] = resolve_dify_schema_refs(datasource["output_schema"])
|
||||
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/datasources",
|
||||
list[PluginDatasourceProviderEntity],
|
||||
params={"page": 1, "page_size": 256},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
for provider in response:
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
|
||||
|
||||
for provider in response:
|
||||
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for tool in provider.declaration.datasources:
|
||||
tool.identity.provider = provider.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
def fetch_datasource_provider(self, tenant_id: str, provider_id: str) -> PluginDatasourceProviderEntity:
|
||||
"""
|
||||
Fetch datasource provider for the given tenant and plugin.
|
||||
"""
|
||||
if provider_id == "langgenius/file/file":
|
||||
return PluginDatasourceProviderEntity.model_validate(self._get_local_file_datasource_provider())
|
||||
|
||||
tool_provider_id = DatasourceProviderID(provider_id)
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
data = json_response.get("data")
|
||||
if data:
|
||||
for datasource in data.get("declaration", {}).get("datasources", []):
|
||||
datasource["identity"]["provider"] = tool_provider_id.provider_name
|
||||
if datasource.get("output_schema"):
|
||||
datasource["output_schema"] = resolve_dify_schema_refs(datasource["output_schema"])
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/datasource",
|
||||
PluginDatasourceProviderEntity,
|
||||
params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for datasource in response.declaration.datasources:
|
||||
datasource.identity.provider = response.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
def get_website_crawl(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
datasource_provider: str,
|
||||
datasource_name: str,
|
||||
credentials: dict[str, Any],
|
||||
datasource_parameters: Mapping[str, Any],
|
||||
provider_type: str,
|
||||
) -> Generator[WebsiteCrawlMessage, None, None]:
|
||||
"""
|
||||
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||
"""
|
||||
|
||||
datasource_provider_id = GenericProviderID(datasource_provider)
|
||||
|
||||
return self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/get_website_crawl",
|
||||
WebsiteCrawlMessage,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": datasource_provider_id.provider_name,
|
||||
"datasource": datasource_name,
|
||||
"credentials": credentials,
|
||||
"datasource_parameters": datasource_parameters,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": datasource_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def get_online_document_pages(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
datasource_provider: str,
|
||||
datasource_name: str,
|
||||
credentials: dict[str, Any],
|
||||
datasource_parameters: Mapping[str, Any],
|
||||
provider_type: str,
|
||||
) -> Generator[OnlineDocumentPagesMessage, None, None]:
|
||||
"""
|
||||
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||
"""
|
||||
|
||||
datasource_provider_id = GenericProviderID(datasource_provider)
|
||||
|
||||
return self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages",
|
||||
OnlineDocumentPagesMessage,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": datasource_provider_id.provider_name,
|
||||
"datasource": datasource_name,
|
||||
"credentials": credentials,
|
||||
"datasource_parameters": datasource_parameters,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": datasource_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def get_online_document_page_content(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
datasource_provider: str,
|
||||
datasource_name: str,
|
||||
credentials: dict[str, Any],
|
||||
datasource_parameters: GetOnlineDocumentPageContentRequest,
|
||||
provider_type: str,
|
||||
) -> Generator[DatasourceMessage, None, None]:
|
||||
"""
|
||||
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||
"""
|
||||
|
||||
datasource_provider_id = GenericProviderID(datasource_provider)
|
||||
|
||||
return self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/get_online_document_page_content",
|
||||
DatasourceMessage,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": datasource_provider_id.provider_name,
|
||||
"datasource": datasource_name,
|
||||
"credentials": credentials,
|
||||
"page": datasource_parameters.model_dump(),
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": datasource_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def online_drive_browse_files(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
datasource_provider: str,
|
||||
datasource_name: str,
|
||||
credentials: dict[str, Any],
|
||||
request: OnlineDriveBrowseFilesRequest,
|
||||
provider_type: str,
|
||||
) -> Generator[OnlineDriveBrowseFilesResponse, None, None]:
|
||||
"""
|
||||
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||
"""
|
||||
|
||||
datasource_provider_id = GenericProviderID(datasource_provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/online_drive_browse_files",
|
||||
OnlineDriveBrowseFilesResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": datasource_provider_id.provider_name,
|
||||
"datasource": datasource_name,
|
||||
"credentials": credentials,
|
||||
"request": request.model_dump(),
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": datasource_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
yield from response
|
||||
|
||||
def online_drive_download_file(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
datasource_provider: str,
|
||||
datasource_name: str,
|
||||
credentials: dict[str, Any],
|
||||
request: OnlineDriveDownloadFileRequest,
|
||||
provider_type: str,
|
||||
) -> Generator[DatasourceMessage, None, None]:
|
||||
"""
|
||||
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||
"""
|
||||
|
||||
datasource_provider_id = GenericProviderID(datasource_provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/online_drive_download_file",
|
||||
DatasourceMessage,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": datasource_provider_id.provider_name,
|
||||
"datasource": datasource_name,
|
||||
"credentials": credentials,
|
||||
"request": request.model_dump(),
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": datasource_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
yield from response
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, tenant_id: str, user_id: str, provider: str, plugin_id: str, credentials: dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
# datasource_provider_id = GenericProviderID(provider_id)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/validate_credentials",
|
||||
PluginBasicBooleanResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.result
|
||||
|
||||
return False
|
||||
|
||||
def _get_local_file_datasource_provider(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": "langgenius/file/file",
|
||||
"plugin_id": "langgenius/file",
|
||||
"provider": "file",
|
||||
"plugin_unique_identifier": "langgenius/file:0.0.1@dify",
|
||||
"declaration": {
|
||||
"identity": {
|
||||
"author": "langgenius",
|
||||
"name": "file",
|
||||
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
|
||||
"icon": "https://assets.dify.ai/images/File%20Upload.svg",
|
||||
"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
|
||||
},
|
||||
"credentials_schema": [],
|
||||
"provider_type": "local_file",
|
||||
"datasources": [
|
||||
{
|
||||
"identity": {
|
||||
"author": "langgenius",
|
||||
"name": "upload-file",
|
||||
"provider": "file",
|
||||
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
|
||||
},
|
||||
"parameters": [],
|
||||
"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
17
dify/api/core/plugin/impl/debugging.py
Normal file
17
dify/api/core/plugin/impl/debugging.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
|
||||
|
||||
class PluginDebuggingClient(BasePluginClient):
|
||||
def get_debugging_key(self, tenant_id: str) -> str:
|
||||
"""
|
||||
Get the debugging key for the given tenant.
|
||||
"""
|
||||
|
||||
class Response(BaseModel):
|
||||
key: str
|
||||
|
||||
response = self._request_with_plugin_daemon_response("POST", f"plugin/{tenant_id}/debugging/key", Response)
|
||||
|
||||
return response.key
|
||||
47
dify/api/core/plugin/impl/dynamic_select.py
Normal file
47
dify/api/core/plugin/impl/dynamic_select.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.plugin.entities.plugin_daemon import PluginDynamicSelectOptionsResponse
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from models.provider_ids import GenericProviderID
|
||||
|
||||
|
||||
class DynamicSelectClient(BasePluginClient):
|
||||
def fetch_dynamic_select_options(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
action: str,
|
||||
credentials: Mapping[str, Any],
|
||||
credential_type: str,
|
||||
parameter: str,
|
||||
) -> PluginDynamicSelectOptionsResponse:
|
||||
"""
|
||||
Fetch dynamic select options for a plugin parameter.
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/dynamic_select/fetch_parameter_options",
|
||||
PluginDynamicSelectOptionsResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": GenericProviderID(provider).provider_name,
|
||||
"credentials": credentials,
|
||||
"credential_type": credential_type,
|
||||
"provider_action": action,
|
||||
"parameter": parameter,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for options in response:
|
||||
return options
|
||||
|
||||
raise ValueError(f"Plugin service returned no options for parameter '{parameter}' in provider '{provider}'")
|
||||
116
dify/api/core/plugin/impl/endpoint.py
Normal file
116
dify/api/core/plugin/impl/endpoint.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from core.plugin.entities.endpoint import EndpointEntityWithInstance
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
|
||||
|
||||
class PluginEndpointClient(BasePluginClient):
|
||||
def create_endpoint(
|
||||
self, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict
|
||||
) -> bool:
|
||||
"""
|
||||
Create an endpoint for the given plugin.
|
||||
|
||||
Errors will be raised if any error occurs.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/endpoint/setup",
|
||||
bool,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"plugin_unique_identifier": plugin_unique_identifier,
|
||||
"settings": settings,
|
||||
"name": name,
|
||||
},
|
||||
)
|
||||
|
||||
def list_endpoints(self, tenant_id: str, user_id: str, page: int, page_size: int):
|
||||
"""
|
||||
List all endpoints for the given tenant and user.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/endpoint/list",
|
||||
list[EndpointEntityWithInstance],
|
||||
params={"page": page, "page_size": page_size},
|
||||
)
|
||||
|
||||
def list_endpoints_for_single_plugin(self, tenant_id: str, user_id: str, plugin_id: str, page: int, page_size: int):
|
||||
"""
|
||||
List all endpoints for the given tenant, user and plugin.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/endpoint/list/plugin",
|
||||
list[EndpointEntityWithInstance],
|
||||
params={"plugin_id": plugin_id, "page": page, "page_size": page_size},
|
||||
)
|
||||
|
||||
def update_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict):
|
||||
"""
|
||||
Update the settings of the given endpoint.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/endpoint/update",
|
||||
bool,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"endpoint_id": endpoint_id,
|
||||
"name": name,
|
||||
"settings": settings,
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def delete_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
|
||||
"""
|
||||
Delete the given endpoint.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/endpoint/remove",
|
||||
bool,
|
||||
data={
|
||||
"endpoint_id": endpoint_id,
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def enable_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
|
||||
"""
|
||||
Enable the given endpoint.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/endpoint/enable",
|
||||
bool,
|
||||
data={
|
||||
"endpoint_id": endpoint_id,
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def disable_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
|
||||
"""
|
||||
Disable the given endpoint.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/endpoint/disable",
|
||||
bool,
|
||||
data={
|
||||
"endpoint_id": endpoint_id,
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
85
dify/api/core/plugin/impl/exc.py
Normal file
85
dify/api/core/plugin/impl/exc.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from extensions.ext_logging import get_request_id
|
||||
|
||||
|
||||
class PluginDaemonError(Exception):
|
||||
"""Base class for all plugin daemon errors."""
|
||||
|
||||
def __init__(self, description: str):
|
||||
self.description = description
|
||||
|
||||
def __str__(self) -> str:
|
||||
# returns the class name and description
|
||||
return f"req_id: {get_request_id()} {self.__class__.__name__}: {self.description}"
|
||||
|
||||
|
||||
class PluginDaemonInternalError(PluginDaemonError):
|
||||
pass
|
||||
|
||||
|
||||
class PluginDaemonClientSideError(PluginDaemonError):
|
||||
pass
|
||||
|
||||
|
||||
class PluginDaemonInternalServerError(PluginDaemonInternalError):
|
||||
description: str = "Internal Server Error"
|
||||
|
||||
|
||||
class PluginDaemonUnauthorizedError(PluginDaemonInternalError):
|
||||
description: str = "Unauthorized"
|
||||
|
||||
|
||||
class PluginDaemonNotFoundError(PluginDaemonInternalError):
|
||||
description: str = "Not Found"
|
||||
|
||||
|
||||
class PluginDaemonBadRequestError(PluginDaemonClientSideError):
|
||||
description: str = "Bad Request"
|
||||
|
||||
|
||||
class PluginInvokeError(PluginDaemonClientSideError, ValueError):
|
||||
description: str = "Invoke Error"
|
||||
|
||||
def _get_error_object(self) -> Mapping:
|
||||
try:
|
||||
return TypeAdapter(Mapping).validate_json(self.description)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
def get_error_type(self) -> str:
|
||||
return self._get_error_object().get("error_type", "unknown")
|
||||
|
||||
def get_error_message(self) -> str:
|
||||
try:
|
||||
return self._get_error_object().get("message", "unknown")
|
||||
except Exception:
|
||||
return self.description
|
||||
|
||||
def to_user_friendly_error(self, plugin_name: str = "currently running plugin") -> str:
|
||||
"""
|
||||
Convert the error to a user-friendly error message.
|
||||
|
||||
:param plugin_name: The name of the plugin that caused the error.
|
||||
:return: A user-friendly error message.
|
||||
"""
|
||||
return (
|
||||
f"An error occurred in the {plugin_name}, "
|
||||
f"please contact the author of {plugin_name} for help, "
|
||||
f"error type: {self.get_error_type()}, "
|
||||
f"error details: {self.get_error_message()}"
|
||||
)
|
||||
|
||||
|
||||
class PluginUniqueIdentifierError(PluginDaemonClientSideError):
|
||||
description: str = "Unique Identifier Error"
|
||||
|
||||
|
||||
class PluginNotFoundError(PluginDaemonClientSideError):
|
||||
description: str = "Plugin Not Found"
|
||||
|
||||
|
||||
class PluginPermissionDeniedError(PluginDaemonClientSideError):
|
||||
description: str = "Permission Denied"
|
||||
531
dify/api/core/plugin/impl/model.py
Normal file
531
dify/api/core/plugin/impl/model.py
Normal file
@@ -0,0 +1,531 @@
|
||||
import binascii
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import IO
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResultChunk
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginBasicBooleanResponse,
|
||||
PluginDaemonInnerError,
|
||||
PluginLLMNumTokensResponse,
|
||||
PluginModelProviderEntity,
|
||||
PluginModelSchemaEntity,
|
||||
PluginStringResultResponse,
|
||||
PluginTextEmbeddingNumTokensResponse,
|
||||
PluginVoicesResponse,
|
||||
)
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
|
||||
|
||||
class PluginModelClient(BasePluginClient):
|
||||
def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
|
||||
"""
|
||||
Fetch model providers for the given tenant.
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/models",
|
||||
list[PluginModelProviderEntity],
|
||||
params={"page": 1, "page_size": 256},
|
||||
)
|
||||
return response
|
||||
|
||||
def get_model_schema(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
) -> AIModelEntity | None:
|
||||
"""
|
||||
Get model schema
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/model/schema",
|
||||
PluginModelSchemaEntity,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": model_type,
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.model_schema
|
||||
|
||||
return None
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict
|
||||
) -> bool:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/model/validate_provider_credentials",
|
||||
PluginBasicBooleanResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
if resp.credentials and isinstance(resp.credentials, dict):
|
||||
credentials.update(resp.credentials)
|
||||
|
||||
return resp.result
|
||||
|
||||
return False
|
||||
|
||||
def validate_model_credentials(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
) -> bool:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/model/validate_model_credentials",
|
||||
PluginBasicBooleanResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": model_type,
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
if resp.credentials and isinstance(resp.credentials, dict):
|
||||
credentials.update(resp.credentials)
|
||||
|
||||
return resp.result
|
||||
|
||||
return False
|
||||
|
||||
def invoke_llm(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict | None = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Invoke llm
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/llm/invoke",
|
||||
type_=LLMResultChunk,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": "llm",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"prompt_messages": prompt_messages,
|
||||
"model_parameters": model_parameters,
|
||||
"tools": tools,
|
||||
"stop": stop,
|
||||
"stream": stream,
|
||||
},
|
||||
}
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
yield from response
|
||||
except PluginDaemonInnerError as e:
|
||||
raise ValueError(e.message + str(e.code))
|
||||
|
||||
def get_llm_num_tokens(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for llm
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
|
||||
type_=PluginLLMNumTokensResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": model_type,
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"prompt_messages": prompt_messages,
|
||||
"tools": tools,
|
||||
},
|
||||
}
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.num_tokens
|
||||
|
||||
return 0
|
||||
|
||||
def invoke_text_embedding(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
input_type: str,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
|
||||
type_=TextEmbeddingResult,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": "text-embedding",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"texts": texts,
|
||||
"input_type": input_type,
|
||||
},
|
||||
}
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise ValueError("Failed to invoke text embedding")
|
||||
|
||||
def get_text_embedding_num_tokens(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
) -> list[int]:
|
||||
"""
|
||||
Get number of tokens for text embedding
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
|
||||
type_=PluginTextEmbeddingNumTokensResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": "text-embedding",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"texts": texts,
|
||||
},
|
||||
}
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.num_tokens
|
||||
|
||||
return []
|
||||
|
||||
def invoke_rerank(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
|
||||
type_=RerankResult,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": "rerank",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"query": query,
|
||||
"docs": docs,
|
||||
"score_threshold": score_threshold,
|
||||
"top_n": top_n,
|
||||
},
|
||||
}
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise ValueError("Failed to invoke rerank")
|
||||
|
||||
def invoke_tts(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
content_text: str,
|
||||
voice: str,
|
||||
) -> Generator[bytes, None, None]:
|
||||
"""
|
||||
Invoke tts
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/tts/invoke",
|
||||
type_=PluginStringResultResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": "tts",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"tenant_id": tenant_id,
|
||||
"content_text": content_text,
|
||||
"voice": voice,
|
||||
},
|
||||
}
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
for result in response:
|
||||
hex_str = result.result
|
||||
yield binascii.unhexlify(hex_str)
|
||||
except PluginDaemonInnerError as e:
|
||||
raise ValueError(e.message + str(e.code))
|
||||
|
||||
def get_tts_model_voices(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
language: str | None = None,
|
||||
):
|
||||
"""
|
||||
Get tts model voices
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
|
||||
type_=PluginVoicesResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": "tts",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"language": language,
|
||||
},
|
||||
}
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
voices = []
|
||||
for voice in resp.voices:
|
||||
voices.append({"name": voice.name, "value": voice.value})
|
||||
|
||||
return voices
|
||||
|
||||
return []
|
||||
|
||||
def invoke_speech_to_text(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
file: IO[bytes],
|
||||
) -> str:
|
||||
"""
|
||||
Invoke speech to text
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
|
||||
type_=PluginStringResultResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": "speech2text",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"file": binascii.hexlify(file.read()).decode(),
|
||||
},
|
||||
}
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.result
|
||||
|
||||
raise ValueError("Failed to invoke speech to text")
|
||||
|
||||
def invoke_moderation(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
text: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Invoke moderation
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
|
||||
type_=PluginBasicBooleanResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": "moderation",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"text": text,
|
||||
},
|
||||
}
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.result
|
||||
|
||||
raise ValueError("Failed to invoke moderation")
|
||||
150
dify/api/core/plugin/impl/oauth.py
Normal file
150
dify/api/core/plugin/impl/oauth.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import binascii
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from werkzeug import Request
|
||||
|
||||
from core.plugin.entities.plugin_daemon import PluginOAuthAuthorizationUrlResponse, PluginOAuthCredentialsResponse
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
|
||||
|
||||
class OAuthHandler(BasePluginClient):
|
||||
def get_authorization_url(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
redirect_uri: str,
|
||||
system_credentials: Mapping[str, Any],
|
||||
) -> PluginOAuthAuthorizationUrlResponse:
|
||||
try:
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
|
||||
PluginOAuthAuthorizationUrlResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"redirect_uri": redirect_uri,
|
||||
"system_credentials": system_credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
for resp in response:
|
||||
return resp
|
||||
raise ValueError("No response received from plugin daemon for authorization URL request.")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error getting authorization URL: {e}")
|
||||
|
||||
def get_credentials(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
redirect_uri: str,
|
||||
system_credentials: Mapping[str, Any],
|
||||
request: Request,
|
||||
) -> PluginOAuthCredentialsResponse:
|
||||
"""
|
||||
Get credentials from the given request.
|
||||
"""
|
||||
|
||||
try:
|
||||
# encode request to raw http request
|
||||
raw_request_bytes = self._convert_request_to_raw_data(request)
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
|
||||
PluginOAuthCredentialsResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"redirect_uri": redirect_uri,
|
||||
"system_credentials": system_credentials,
|
||||
# for json serialization
|
||||
"raw_http_request": binascii.hexlify(raw_request_bytes).decode(),
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
for resp in response:
|
||||
return resp
|
||||
raise ValueError("No response received from plugin daemon for authorization URL request.")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error getting credentials: {e}")
|
||||
|
||||
def refresh_credentials(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
redirect_uri: str,
|
||||
system_credentials: Mapping[str, Any],
|
||||
credentials: Mapping[str, Any],
|
||||
) -> PluginOAuthCredentialsResponse:
|
||||
try:
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/oauth/refresh_credentials",
|
||||
PluginOAuthCredentialsResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"redirect_uri": redirect_uri,
|
||||
"system_credentials": system_credentials,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
for resp in response:
|
||||
return resp
|
||||
raise ValueError("No response received from plugin daemon for refresh credentials request.")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error refreshing credentials: {e}")
|
||||
|
||||
def _convert_request_to_raw_data(self, request: Request) -> bytes:
|
||||
"""
|
||||
Convert a Request object to raw HTTP data.
|
||||
|
||||
Args:
|
||||
request: The Request object to convert.
|
||||
|
||||
Returns:
|
||||
The raw HTTP data as bytes.
|
||||
"""
|
||||
# Start with the request line
|
||||
method = request.method
|
||||
path = request.full_path
|
||||
protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1")
|
||||
raw_data = f"{method} {path} {protocol}\r\n".encode()
|
||||
|
||||
# Add headers
|
||||
for header_name, header_value in request.headers.items():
|
||||
raw_data += f"{header_name}: {header_value}\r\n".encode()
|
||||
|
||||
# Add empty line to separate headers from body
|
||||
raw_data += b"\r\n"
|
||||
|
||||
# Add body if exists
|
||||
body = request.get_data(as_text=False)
|
||||
if body:
|
||||
raw_data += body
|
||||
|
||||
return raw_data
|
||||
300
dify/api/core/plugin/impl/plugin.py
Normal file
300
dify/api/core/plugin/impl/plugin.py
Normal file
@@ -0,0 +1,300 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from requests import HTTPError
|
||||
|
||||
from core.plugin.entities.bundle import PluginBundleDependency
|
||||
from core.plugin.entities.plugin import (
|
||||
MissingPluginDependency,
|
||||
PluginDeclaration,
|
||||
PluginEntity,
|
||||
PluginInstallation,
|
||||
PluginInstallationSource,
|
||||
)
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginDecodeResponse,
|
||||
PluginInstallTask,
|
||||
PluginInstallTaskStartResponse,
|
||||
PluginListResponse,
|
||||
PluginReadmeResponse,
|
||||
)
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from models.provider_ids import GenericProviderID
|
||||
|
||||
|
||||
class PluginInstaller(BasePluginClient):
|
||||
def fetch_plugin_readme(self, tenant_id: str, plugin_unique_identifier: str, language: str) -> str:
|
||||
"""
|
||||
Fetch plugin readme
|
||||
"""
|
||||
try:
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/fetch/readme",
|
||||
PluginReadmeResponse,
|
||||
params={
|
||||
"tenant_id": tenant_id,
|
||||
"plugin_unique_identifier": plugin_unique_identifier,
|
||||
"language": language,
|
||||
},
|
||||
)
|
||||
return response.content
|
||||
except HTTPError as e:
|
||||
message = e.args[0]
|
||||
if "404" in message:
|
||||
return ""
|
||||
raise e
|
||||
|
||||
def fetch_plugin_by_identifier(
|
||||
self,
|
||||
tenant_id: str,
|
||||
identifier: str,
|
||||
) -> bool:
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/fetch/identifier",
|
||||
bool,
|
||||
params={"plugin_unique_identifier": identifier},
|
||||
)
|
||||
|
||||
def list_plugins(self, tenant_id: str) -> list[PluginEntity]:
|
||||
result = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/list",
|
||||
PluginListResponse,
|
||||
params={"page": 1, "page_size": 256, "response_type": "paged"},
|
||||
)
|
||||
return result.list
|
||||
|
||||
def list_plugins_with_total(self, tenant_id: str, page: int, page_size: int) -> PluginListResponse:
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/list",
|
||||
PluginListResponse,
|
||||
params={"page": page, "page_size": page_size, "response_type": "paged"},
|
||||
)
|
||||
|
||||
def upload_pkg(
|
||||
self,
|
||||
tenant_id: str,
|
||||
pkg: bytes,
|
||||
verify_signature: bool = False,
|
||||
) -> PluginDecodeResponse:
|
||||
"""
|
||||
Upload a plugin package and return the plugin unique identifier.
|
||||
"""
|
||||
body = {
|
||||
"dify_pkg": ("dify_pkg", pkg, "application/octet-stream"),
|
||||
}
|
||||
|
||||
data = {
|
||||
"verify_signature": "true" if verify_signature else "false",
|
||||
}
|
||||
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/install/upload/package",
|
||||
PluginDecodeResponse,
|
||||
files=body,
|
||||
data=data,
|
||||
)
|
||||
|
||||
def upload_bundle(
|
||||
self,
|
||||
tenant_id: str,
|
||||
bundle: bytes,
|
||||
verify_signature: bool = False,
|
||||
) -> Sequence[PluginBundleDependency]:
|
||||
"""
|
||||
Upload a plugin bundle and return the dependencies.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/install/upload/bundle",
|
||||
list[PluginBundleDependency],
|
||||
files={"dify_bundle": ("dify_bundle", bundle, "application/octet-stream")},
|
||||
data={"verify_signature": "true" if verify_signature else "false"},
|
||||
)
|
||||
|
||||
def install_from_identifiers(
|
||||
self,
|
||||
tenant_id: str,
|
||||
identifiers: Sequence[str],
|
||||
source: PluginInstallationSource,
|
||||
metas: list[dict],
|
||||
) -> PluginInstallTaskStartResponse:
|
||||
"""
|
||||
Install a plugin from an identifier.
|
||||
"""
|
||||
# exception will be raised if the request failed
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/install/identifiers",
|
||||
PluginInstallTaskStartResponse,
|
||||
data={
|
||||
"plugin_unique_identifiers": identifiers,
|
||||
"source": source,
|
||||
"metas": metas,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
def fetch_plugin_installation_tasks(self, tenant_id: str, page: int, page_size: int) -> Sequence[PluginInstallTask]:
|
||||
"""
|
||||
Fetch plugin installation tasks.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/install/tasks",
|
||||
list[PluginInstallTask],
|
||||
params={"page": page, "page_size": page_size},
|
||||
)
|
||||
|
||||
def fetch_plugin_installation_task(self, tenant_id: str, task_id: str) -> PluginInstallTask:
|
||||
"""
|
||||
Fetch a plugin installation task.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/install/tasks/{task_id}",
|
||||
PluginInstallTask,
|
||||
)
|
||||
|
||||
def delete_plugin_installation_task(self, tenant_id: str, task_id: str) -> bool:
|
||||
"""
|
||||
Delete a plugin installation task.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/install/tasks/{task_id}/delete",
|
||||
bool,
|
||||
)
|
||||
|
||||
def delete_all_plugin_installation_task_items(self, tenant_id: str) -> bool:
|
||||
"""
|
||||
Delete all plugin installation task items.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/install/tasks/delete_all",
|
||||
bool,
|
||||
)
|
||||
|
||||
def delete_plugin_installation_task_item(self, tenant_id: str, task_id: str, identifier: str) -> bool:
|
||||
"""
|
||||
Delete a plugin installation task item.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/install/tasks/{task_id}/delete/{identifier}",
|
||||
bool,
|
||||
)
|
||||
|
||||
def fetch_plugin_manifest(self, tenant_id: str, plugin_unique_identifier: str) -> PluginDeclaration:
|
||||
"""
|
||||
Fetch a plugin manifest.
|
||||
"""
|
||||
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/fetch/manifest",
|
||||
PluginDeclaration,
|
||||
params={"plugin_unique_identifier": plugin_unique_identifier},
|
||||
)
|
||||
|
||||
def decode_plugin_from_identifier(self, tenant_id: str, plugin_unique_identifier: str) -> PluginDecodeResponse:
|
||||
"""
|
||||
Decode a plugin from an identifier.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/decode/from_identifier",
|
||||
PluginDecodeResponse,
|
||||
data={"plugin_unique_identifier": plugin_unique_identifier},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
def fetch_plugin_installation_by_ids(
|
||||
self, tenant_id: str, plugin_ids: Sequence[str]
|
||||
) -> Sequence[PluginInstallation]:
|
||||
"""
|
||||
Fetch plugin installations by ids.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/installation/fetch/batch",
|
||||
list[PluginInstallation],
|
||||
data={"plugin_ids": plugin_ids},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
def fetch_missing_dependencies(
|
||||
self, tenant_id: str, plugin_unique_identifiers: list[str]
|
||||
) -> list[MissingPluginDependency]:
|
||||
"""
|
||||
Fetch missing dependencies
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/installation/missing",
|
||||
list[MissingPluginDependency],
|
||||
data={"plugin_unique_identifiers": plugin_unique_identifiers},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
def uninstall(self, tenant_id: str, plugin_installation_id: str) -> bool:
|
||||
"""
|
||||
Uninstall a plugin.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/uninstall",
|
||||
bool,
|
||||
data={
|
||||
"plugin_installation_id": plugin_installation_id,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
def upgrade_plugin(
|
||||
self,
|
||||
tenant_id: str,
|
||||
original_plugin_unique_identifier: str,
|
||||
new_plugin_unique_identifier: str,
|
||||
source: PluginInstallationSource,
|
||||
meta: dict,
|
||||
) -> PluginInstallTaskStartResponse:
|
||||
"""
|
||||
Upgrade a plugin.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/install/upgrade",
|
||||
PluginInstallTaskStartResponse,
|
||||
data={
|
||||
"original_plugin_unique_identifier": original_plugin_unique_identifier,
|
||||
"new_plugin_unique_identifier": new_plugin_unique_identifier,
|
||||
"source": source,
|
||||
"meta": meta,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
def check_tools_existence(self, tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]:
|
||||
"""
|
||||
Check if the tools exist
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/tools/check_existence",
|
||||
list[bool],
|
||||
data={
|
||||
"provider_ids": [
|
||||
{
|
||||
"plugin_id": provider_id.plugin_id,
|
||||
"provider_name": provider_id.provider_name,
|
||||
}
|
||||
for provider_id in provider_ids
|
||||
]
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
230
dify/api/core/plugin/impl/tool.py
Normal file
230
dify/api/core/plugin/impl/tool.py
Normal file
@@ -0,0 +1,230 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
# from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType, PluginBasicBooleanResponse, PluginToolProviderEntity
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.plugin.utils.chunk_merger import merge_blob_chunks
|
||||
from core.schemas.resolver import resolve_dify_schema_refs
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from models.provider_ids import GenericProviderID, ToolProviderID
|
||||
|
||||
|
||||
class PluginToolManager(BasePluginClient):
|
||||
def fetch_tool_providers(self, tenant_id: str) -> list[PluginToolProviderEntity]:
|
||||
"""
|
||||
Fetch tool providers for the given tenant.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]):
|
||||
for provider in json_response.get("data", []):
|
||||
declaration = provider.get("declaration", {}) or {}
|
||||
provider_name = declaration.get("identity", {}).get("name")
|
||||
for tool in declaration.get("tools", []):
|
||||
tool["identity"]["provider"] = provider_name
|
||||
# resolve refs
|
||||
if tool.get("output_schema"):
|
||||
tool["output_schema"] = resolve_dify_schema_refs(tool["output_schema"])
|
||||
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/tools",
|
||||
list[PluginToolProviderEntity],
|
||||
params={"page": 1, "page_size": 256},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
for provider in response:
|
||||
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for tool in provider.declaration.tools:
|
||||
tool.identity.provider = provider.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity:
|
||||
"""
|
||||
Fetch tool provider for the given tenant and plugin.
|
||||
"""
|
||||
tool_provider_id = ToolProviderID(provider)
|
||||
|
||||
def transformer(json_response: dict[str, Any]):
|
||||
data = json_response.get("data")
|
||||
if data:
|
||||
for tool in data.get("declaration", {}).get("tools", []):
|
||||
tool["identity"]["provider"] = tool_provider_id.provider_name
|
||||
# resolve refs
|
||||
if tool.get("output_schema"):
|
||||
tool["output_schema"] = resolve_dify_schema_refs(tool["output_schema"])
|
||||
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/tool",
|
||||
PluginToolProviderEntity,
|
||||
params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for tool in response.declaration.tools:
|
||||
tool.identity.provider = response.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
tool_provider: str,
|
||||
tool_name: str,
|
||||
credentials: dict[str, Any],
|
||||
credential_type: CredentialType,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
Invoke the tool with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||
"""
|
||||
|
||||
tool_provider_id = GenericProviderID(tool_provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/tool/invoke",
|
||||
ToolInvokeMessage,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"conversation_id": conversation_id,
|
||||
"app_id": app_id,
|
||||
"message_id": message_id,
|
||||
"data": {
|
||||
"provider": tool_provider_id.provider_name,
|
||||
"tool": tool_name,
|
||||
"credentials": credentials,
|
||||
"credential_type": credential_type,
|
||||
"tool_parameters": tool_parameters,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": tool_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
return merge_blob_chunks(response)
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
tool_provider_id = GenericProviderID(provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/tool/validate_credentials",
|
||||
PluginBasicBooleanResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": tool_provider_id.provider_name,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": tool_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.result
|
||||
|
||||
return False
|
||||
|
||||
def validate_datasource_credentials(
|
||||
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
validate the credentials of the datasource
|
||||
"""
|
||||
tool_provider_id = GenericProviderID(provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/validate_credentials",
|
||||
PluginBasicBooleanResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": tool_provider_id.provider_name,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": tool_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.result
|
||||
|
||||
return False
|
||||
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
credentials: dict[str, Any],
|
||||
tool: str,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> list[ToolParameter]:
|
||||
"""
|
||||
get the runtime parameters of the tool
|
||||
"""
|
||||
tool_provider_id = GenericProviderID(provider)
|
||||
|
||||
class RuntimeParametersResponse(BaseModel):
|
||||
parameters: list[ToolParameter]
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/tool/get_runtime_parameters",
|
||||
RuntimeParametersResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"conversation_id": conversation_id,
|
||||
"app_id": app_id,
|
||||
"message_id": message_id,
|
||||
"data": {
|
||||
"provider": tool_provider_id.provider_name,
|
||||
"tool": tool,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": tool_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.parameters
|
||||
|
||||
return []
|
||||
305
dify/api/core/plugin/impl/trigger.py
Normal file
305
dify/api/core/plugin/impl/trigger.py
Normal file
@@ -0,0 +1,305 @@
|
||||
import binascii
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import Request
|
||||
|
||||
from core.plugin.entities.plugin_daemon import CredentialType, PluginTriggerProviderEntity
|
||||
from core.plugin.entities.request import (
|
||||
TriggerDispatchResponse,
|
||||
TriggerInvokeEventResponse,
|
||||
TriggerSubscriptionResponse,
|
||||
TriggerValidateProviderCredentialsResponse,
|
||||
)
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.plugin.utils.http_parser import serialize_request
|
||||
from core.trigger.entities.entities import Subscription
|
||||
from models.provider_ids import TriggerProviderID
|
||||
|
||||
|
||||
class PluginTriggerClient(BasePluginClient):
|
||||
def fetch_trigger_providers(self, tenant_id: str) -> list[PluginTriggerProviderEntity]:
|
||||
"""
|
||||
Fetch trigger providers for the given tenant.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
|
||||
for provider in json_response.get("data", []):
|
||||
declaration = provider.get("declaration", {}) or {}
|
||||
provider_id = provider.get("plugin_id") + "/" + provider.get("provider")
|
||||
for event in declaration.get("events", []):
|
||||
event["identity"]["provider"] = provider_id
|
||||
|
||||
return json_response
|
||||
|
||||
response: list[PluginTriggerProviderEntity] = self._request_with_plugin_daemon_response(
|
||||
method="GET",
|
||||
path=f"plugin/{tenant_id}/management/triggers",
|
||||
type_=list[PluginTriggerProviderEntity],
|
||||
params={"page": 1, "page_size": 256},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
for provider in response:
|
||||
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each trigger to plugin_id/provider_name
|
||||
for event in provider.declaration.events:
|
||||
event.identity.provider = provider.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
def fetch_trigger_provider(self, tenant_id: str, provider_id: TriggerProviderID) -> PluginTriggerProviderEntity:
|
||||
"""
|
||||
Fetch trigger provider for the given tenant and plugin.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
|
||||
data = json_response.get("data")
|
||||
if data:
|
||||
for event in data.get("declaration", {}).get("events", []):
|
||||
event["identity"]["provider"] = str(provider_id)
|
||||
|
||||
return json_response
|
||||
|
||||
response: PluginTriggerProviderEntity = self._request_with_plugin_daemon_response(
|
||||
method="GET",
|
||||
path=f"plugin/{tenant_id}/management/trigger",
|
||||
type_=PluginTriggerProviderEntity,
|
||||
params={"provider": provider_id.provider_name, "plugin_id": provider_id.plugin_id},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
response.declaration.identity.name = str(provider_id)
|
||||
|
||||
# override the provider name for each trigger to plugin_id/provider_name
|
||||
for event in response.declaration.events:
|
||||
event.identity.provider = str(provider_id)
|
||||
|
||||
return response
|
||||
|
||||
def invoke_trigger_event(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
event_name: str,
|
||||
credentials: Mapping[str, str],
|
||||
credential_type: CredentialType,
|
||||
request: Request,
|
||||
parameters: Mapping[str, Any],
|
||||
subscription: Subscription,
|
||||
payload: Mapping[str, Any],
|
||||
) -> TriggerInvokeEventResponse:
|
||||
"""
|
||||
Invoke a trigger with the given parameters.
|
||||
"""
|
||||
provider_id = TriggerProviderID(provider)
|
||||
response: Generator[TriggerInvokeEventResponse, None, None] = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/trigger/invoke_event",
|
||||
type_=TriggerInvokeEventResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider_id.provider_name,
|
||||
"event": event_name,
|
||||
"credentials": credentials,
|
||||
"credential_type": credential_type,
|
||||
"subscription": subscription.model_dump(),
|
||||
"raw_http_request": binascii.hexlify(serialize_request(request)).decode(),
|
||||
"parameters": parameters,
|
||||
"payload": payload,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise ValueError("No response received from plugin daemon for invoke trigger")
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, tenant_id: str, user_id: str, provider: str, credentials: Mapping[str, str]
|
||||
) -> bool:
|
||||
"""
|
||||
Validate the credentials of the trigger provider.
|
||||
"""
|
||||
provider_id = TriggerProviderID(provider)
|
||||
response: Generator[TriggerValidateProviderCredentialsResponse, None, None] = (
|
||||
self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/trigger/validate_credentials",
|
||||
type_=TriggerValidateProviderCredentialsResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider_id.provider_name,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.result
|
||||
|
||||
raise ValueError("No response received from plugin daemon for validate provider credentials")
|
||||
|
||||
def dispatch_event(
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
subscription: Mapping[str, Any],
|
||||
request: Request,
|
||||
credentials: Mapping[str, str],
|
||||
credential_type: CredentialType,
|
||||
) -> TriggerDispatchResponse:
|
||||
"""
|
||||
Dispatch an event to triggers.
|
||||
"""
|
||||
provider_id = TriggerProviderID(provider)
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/trigger/dispatch_event",
|
||||
type_=TriggerDispatchResponse,
|
||||
data={
|
||||
"data": {
|
||||
"provider": provider_id.provider_name,
|
||||
"subscription": subscription,
|
||||
"credentials": credentials,
|
||||
"credential_type": credential_type,
|
||||
"raw_http_request": binascii.hexlify(serialize_request(request)).decode(),
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise ValueError("No response received from plugin daemon for dispatch event")
|
||||
|
||||
def subscribe(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
credentials: Mapping[str, str],
|
||||
credential_type: CredentialType,
|
||||
endpoint: str,
|
||||
parameters: Mapping[str, Any],
|
||||
) -> TriggerSubscriptionResponse:
|
||||
"""
|
||||
Subscribe to a trigger.
|
||||
"""
|
||||
provider_id = TriggerProviderID(provider)
|
||||
response: Generator[TriggerSubscriptionResponse, None, None] = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/trigger/subscribe",
|
||||
type_=TriggerSubscriptionResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider_id.provider_name,
|
||||
"credentials": credentials,
|
||||
"credential_type": credential_type,
|
||||
"endpoint": endpoint,
|
||||
"parameters": parameters,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise ValueError("No response received from plugin daemon for subscribe")
|
||||
|
||||
def unsubscribe(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
subscription: Subscription,
|
||||
credentials: Mapping[str, str],
|
||||
credential_type: CredentialType,
|
||||
) -> TriggerSubscriptionResponse:
|
||||
"""
|
||||
Unsubscribe from a trigger.
|
||||
"""
|
||||
provider_id = TriggerProviderID(provider)
|
||||
response: Generator[TriggerSubscriptionResponse, None, None] = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/trigger/unsubscribe",
|
||||
type_=TriggerSubscriptionResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider_id.provider_name,
|
||||
"subscription": subscription.model_dump(),
|
||||
"credentials": credentials,
|
||||
"credential_type": credential_type,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise ValueError("No response received from plugin daemon for unsubscribe")
|
||||
|
||||
def refresh(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
subscription: Subscription,
|
||||
credentials: Mapping[str, str],
|
||||
credential_type: CredentialType,
|
||||
) -> TriggerSubscriptionResponse:
|
||||
"""
|
||||
Refresh a trigger subscription.
|
||||
"""
|
||||
provider_id = TriggerProviderID(provider)
|
||||
response: Generator[TriggerSubscriptionResponse, None, None] = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/trigger/refresh",
|
||||
type_=TriggerSubscriptionResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider_id.provider_name,
|
||||
"subscription": subscription.model_dump(),
|
||||
"credentials": credentials,
|
||||
"credential_type": credential_type,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise ValueError("No response received from plugin daemon for refresh")
|
||||
95
dify/api/core/plugin/utils/chunk_merger.py
Normal file
95
dify/api/core/plugin/utils/chunk_merger.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TypeVar, Union
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
MessageType = TypeVar("MessageType", bound=Union[ToolInvokeMessage, AgentInvokeMessage])
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileChunk:
|
||||
"""
|
||||
Buffer for accumulating file chunks during streaming.
|
||||
"""
|
||||
|
||||
total_length: int
|
||||
bytes_written: int = field(default=0, init=False)
|
||||
data: bytearray = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
self.data = bytearray(self.total_length)
|
||||
|
||||
|
||||
def merge_blob_chunks(
|
||||
response: Generator[MessageType, None, None],
|
||||
max_file_size: int = 30 * 1024 * 1024,
|
||||
max_chunk_size: int = 8192,
|
||||
) -> Generator[MessageType, None, None]:
|
||||
"""
|
||||
Merge streaming blob chunks into complete blob messages.
|
||||
|
||||
This function processes a stream of plugin invoke messages, accumulating
|
||||
BLOB_CHUNK messages by their ID until the final chunk is received,
|
||||
then yielding a single complete BLOB message.
|
||||
|
||||
Args:
|
||||
response: Generator yielding messages that may include blob chunks
|
||||
max_file_size: Maximum allowed file size in bytes (default: 30MB)
|
||||
max_chunk_size: Maximum allowed chunk size in bytes (default: 8KB)
|
||||
|
||||
Yields:
|
||||
Messages from the response stream, with blob chunks merged into complete blobs
|
||||
|
||||
Raises:
|
||||
ValueError: If file size exceeds max_file_size or chunk size exceeds max_chunk_size
|
||||
"""
|
||||
files: dict[str, FileChunk] = {}
|
||||
|
||||
for resp in response:
|
||||
if resp.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
|
||||
assert isinstance(resp.message, ToolInvokeMessage.BlobChunkMessage)
|
||||
# Get blob chunk information
|
||||
chunk_id = resp.message.id
|
||||
total_length = resp.message.total_length
|
||||
blob_data = resp.message.blob
|
||||
is_end = resp.message.end
|
||||
|
||||
# Initialize buffer for this file if it doesn't exist
|
||||
if chunk_id not in files:
|
||||
files[chunk_id] = FileChunk(total_length)
|
||||
|
||||
# Check if file is too large (before appending)
|
||||
if files[chunk_id].bytes_written + len(blob_data) > max_file_size:
|
||||
# Delete the file if it's too large
|
||||
del files[chunk_id]
|
||||
raise ValueError(f"File is too large which reached the limit of {max_file_size / 1024 / 1024}MB")
|
||||
|
||||
# Check if single chunk is too large
|
||||
if len(blob_data) > max_chunk_size:
|
||||
raise ValueError(f"File chunk is too large which reached the limit of {max_chunk_size / 1024}KB")
|
||||
|
||||
# Append the blob data to the buffer
|
||||
files[chunk_id].data[files[chunk_id].bytes_written : files[chunk_id].bytes_written + len(blob_data)] = (
|
||||
blob_data
|
||||
)
|
||||
files[chunk_id].bytes_written += len(blob_data)
|
||||
|
||||
# If this is the final chunk, yield a complete blob message
|
||||
if is_end:
|
||||
# Create the appropriate message type based on the response type
|
||||
message_class = type(resp)
|
||||
merged_message = message_class(
|
||||
type=ToolInvokeMessage.MessageType.BLOB,
|
||||
message=ToolInvokeMessage.BlobMessage(
|
||||
blob=bytes(files[chunk_id].data[: files[chunk_id].bytes_written])
|
||||
),
|
||||
meta=resp.meta,
|
||||
)
|
||||
assert isinstance(merged_message, (ToolInvokeMessage, AgentInvokeMessage))
|
||||
yield merged_message # type: ignore
|
||||
# Clean up the buffer
|
||||
del files[chunk_id]
|
||||
else:
|
||||
yield resp
|
||||
21
dify/api/core/plugin/utils/converter.py
Normal file
21
dify/api/core/plugin/utils/converter.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import Any
|
||||
|
||||
from core.file.models import File
|
||||
from core.tools.entities.tool_entities import ToolSelector
|
||||
|
||||
|
||||
def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
for parameter_name, parameter in parameters.items():
|
||||
if isinstance(parameter, File):
|
||||
parameters[parameter_name] = parameter.to_plugin_parameter()
|
||||
elif isinstance(parameter, list) and all(isinstance(p, File) for p in parameter):
|
||||
parameters[parameter_name] = []
|
||||
for p in parameter:
|
||||
parameters[parameter_name].append(p.to_plugin_parameter())
|
||||
elif isinstance(parameter, ToolSelector):
|
||||
parameters[parameter_name] = parameter.to_plugin_parameter()
|
||||
elif isinstance(parameter, list) and all(isinstance(p, ToolSelector) for p in parameter):
|
||||
parameters[parameter_name] = []
|
||||
for p in parameter:
|
||||
parameters[parameter_name].append(p.to_plugin_parameter())
|
||||
return parameters
|
||||
163
dify/api/core/plugin/utils/http_parser.py
Normal file
163
dify/api/core/plugin/utils/http_parser.py
Normal file
@@ -0,0 +1,163 @@
|
||||
from io import BytesIO
|
||||
|
||||
from flask import Request, Response
|
||||
from werkzeug.datastructures import Headers
|
||||
|
||||
|
||||
def serialize_request(request: Request) -> bytes:
|
||||
method = request.method
|
||||
path = request.full_path.rstrip("?")
|
||||
raw = f"{method} {path} HTTP/1.1\r\n".encode()
|
||||
|
||||
for name, value in request.headers.items():
|
||||
raw += f"{name}: {value}\r\n".encode()
|
||||
|
||||
raw += b"\r\n"
|
||||
|
||||
body = request.get_data(as_text=False)
|
||||
if body:
|
||||
raw += body
|
||||
|
||||
return raw
|
||||
|
||||
|
||||
def deserialize_request(raw_data: bytes) -> Request:
|
||||
header_end = raw_data.find(b"\r\n\r\n")
|
||||
if header_end == -1:
|
||||
header_end = raw_data.find(b"\n\n")
|
||||
if header_end == -1:
|
||||
header_data = raw_data
|
||||
body = b""
|
||||
else:
|
||||
header_data = raw_data[:header_end]
|
||||
body = raw_data[header_end + 2 :]
|
||||
else:
|
||||
header_data = raw_data[:header_end]
|
||||
body = raw_data[header_end + 4 :]
|
||||
|
||||
lines = header_data.split(b"\r\n")
|
||||
if len(lines) == 1 and b"\n" in lines[0]:
|
||||
lines = header_data.split(b"\n")
|
||||
|
||||
if not lines or not lines[0]:
|
||||
raise ValueError("Empty HTTP request")
|
||||
|
||||
request_line = lines[0].decode("utf-8", errors="ignore")
|
||||
parts = request_line.split(" ", 2)
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f"Invalid request line: {request_line}")
|
||||
|
||||
method = parts[0]
|
||||
full_path = parts[1]
|
||||
protocol = parts[2] if len(parts) > 2 else "HTTP/1.1"
|
||||
|
||||
if "?" in full_path:
|
||||
path, query_string = full_path.split("?", 1)
|
||||
else:
|
||||
path = full_path
|
||||
query_string = ""
|
||||
|
||||
headers = Headers()
|
||||
for line in lines[1:]:
|
||||
if not line:
|
||||
continue
|
||||
line_str = line.decode("utf-8", errors="ignore")
|
||||
if ":" not in line_str:
|
||||
continue
|
||||
name, value = line_str.split(":", 1)
|
||||
headers.add(name, value.strip())
|
||||
|
||||
host = headers.get("Host", "localhost")
|
||||
if ":" in host:
|
||||
server_name, server_port = host.rsplit(":", 1)
|
||||
else:
|
||||
server_name = host
|
||||
server_port = "80"
|
||||
|
||||
environ = {
|
||||
"REQUEST_METHOD": method,
|
||||
"PATH_INFO": path,
|
||||
"QUERY_STRING": query_string,
|
||||
"SERVER_NAME": server_name,
|
||||
"SERVER_PORT": server_port,
|
||||
"SERVER_PROTOCOL": protocol,
|
||||
"wsgi.input": BytesIO(body),
|
||||
"wsgi.url_scheme": "http",
|
||||
}
|
||||
|
||||
if "Content-Type" in headers:
|
||||
content_type = headers.get("Content-Type")
|
||||
if content_type is not None:
|
||||
environ["CONTENT_TYPE"] = content_type
|
||||
|
||||
if "Content-Length" in headers:
|
||||
content_length = headers.get("Content-Length")
|
||||
if content_length is not None:
|
||||
environ["CONTENT_LENGTH"] = content_length
|
||||
elif body:
|
||||
environ["CONTENT_LENGTH"] = str(len(body))
|
||||
|
||||
for name, value in headers.items():
|
||||
if name.upper() in ("CONTENT-TYPE", "CONTENT-LENGTH"):
|
||||
continue
|
||||
env_name = f"HTTP_{name.upper().replace('-', '_')}"
|
||||
environ[env_name] = value
|
||||
|
||||
return Request(environ)
|
||||
|
||||
|
||||
def serialize_response(response: Response) -> bytes:
|
||||
raw = f"HTTP/1.1 {response.status}\r\n".encode()
|
||||
|
||||
for name, value in response.headers.items():
|
||||
raw += f"{name}: {value}\r\n".encode()
|
||||
|
||||
raw += b"\r\n"
|
||||
|
||||
body = response.get_data(as_text=False)
|
||||
if body:
|
||||
raw += body
|
||||
|
||||
return raw
|
||||
|
||||
|
||||
def deserialize_response(raw_data: bytes) -> Response:
|
||||
header_end = raw_data.find(b"\r\n\r\n")
|
||||
if header_end == -1:
|
||||
header_end = raw_data.find(b"\n\n")
|
||||
if header_end == -1:
|
||||
header_data = raw_data
|
||||
body = b""
|
||||
else:
|
||||
header_data = raw_data[:header_end]
|
||||
body = raw_data[header_end + 2 :]
|
||||
else:
|
||||
header_data = raw_data[:header_end]
|
||||
body = raw_data[header_end + 4 :]
|
||||
|
||||
lines = header_data.split(b"\r\n")
|
||||
if len(lines) == 1 and b"\n" in lines[0]:
|
||||
lines = header_data.split(b"\n")
|
||||
|
||||
if not lines or not lines[0]:
|
||||
raise ValueError("Empty HTTP response")
|
||||
|
||||
status_line = lines[0].decode("utf-8", errors="ignore")
|
||||
parts = status_line.split(" ", 2)
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f"Invalid status line: {status_line}")
|
||||
|
||||
status_code = int(parts[1])
|
||||
|
||||
response = Response(response=body, status=status_code)
|
||||
|
||||
for line in lines[1:]:
|
||||
if not line:
|
||||
continue
|
||||
line_str = line.decode("utf-8", errors="ignore")
|
||||
if ":" not in line_str:
|
||||
continue
|
||||
name, value = line_str.split(":", 1)
|
||||
response.headers[name] = value.strip()
|
||||
|
||||
return response
|
||||
Reference in New Issue
Block a user