dify
@@ -0,0 +1,34 @@
import os
from typing import Literal

import pytest
from _pytest.monkeypatch import MonkeyPatch
from jinja2 import Template

from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage

MOCK = os.getenv("MOCK_SWITCH", "false") == "true"


class MockedCodeExecutor:
    @classmethod
    def invoke(cls, language: Literal["python3", "javascript", "jinja2"], code: str, inputs: dict):
        # invoke directly
        match language:
            case CodeLanguage.PYTHON3:
                return {"result": 3}
            case CodeLanguage.JINJA2:
                return {"result": Template(code).render(inputs)}
            case _:
                raise Exception("Language not supported")


@pytest.fixture
def setup_code_executor_mock(request, monkeypatch: MonkeyPatch):
    if not MOCK:
        yield
        return

    monkeypatch.setattr(CodeExecutor, "execute_workflow_code_template", MockedCodeExecutor.invoke)
    yield
    monkeypatch.undo()
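Usage note: this fixture is consumed indirectly through pytest parametrization, exactly as the code-node integration tests later in this commit do. A minimal, illustrative sketch (assuming MOCK_SWITCH=true so the patch is active; the test name is hypothetical):

import pytest

from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage


@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_python3_template_is_mocked(setup_code_executor_mock):
    # With the patch in place, any python3 workflow code template resolves to
    # the canned {"result": 3} payload from MockedCodeExecutor.
    result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.PYTHON3, code="", inputs={})
    assert result == {"result": 3}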
@@ -0,0 +1,56 @@
import os
from json import dumps
from typing import Literal

import httpx
import pytest
from _pytest.monkeypatch import MonkeyPatch

from core.helper import ssrf_proxy

MOCK = os.getenv("MOCK_SWITCH", "false") == "true"


class MockedHttp:
    @staticmethod
    def httpx_request(
        method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
    ) -> httpx.Response:
        """
        Mocked httpx.request
        """
        if url == "http://404.com":
            response = httpx.Response(status_code=404, request=httpx.Request(method, url), content=b"Not Found")
            return response

        # get data, files
        data = kwargs.get("data")
        files = kwargs.get("files")
        json = kwargs.get("json")
        content = kwargs.get("content")
        if data is not None:
            resp = dumps(data).encode("utf-8")
        elif files is not None:
            resp = dumps(files).encode("utf-8")
        elif json is not None:
            resp = dumps(json).encode("utf-8")
        elif content is not None:
            resp = content
        else:
            resp = b"OK"

        response = httpx.Response(
            status_code=200, request=httpx.Request(method, url), headers=kwargs.get("headers", {}), content=resp
        )
        return response


@pytest.fixture
def setup_http_mock(request, monkeypatch: MonkeyPatch):
    if not MOCK:
        yield
        return

    monkeypatch.setattr(ssrf_proxy, "make_request", MockedHttp.httpx_request)
    yield
    monkeypatch.undo()
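The mock simply echoes the request back: data, files and json bodies are JSON-dumped into the response content, raw content is passed through, and anything else returns b"OK", which is what lets the HTTP-node tests below assert on the request they built. Calling the static method directly illustrates this (illustration only; the real tests go through ssrf_proxy.make_request):

resp = MockedHttp.httpx_request("POST", "http://example.com", json={"a": 1})
assert resp.status_code == 200
assert resp.content == b'{"a": 1}'

resp = MockedHttp.httpx_request("GET", "http://404.com")
assert resp.status_code == 404
assert resp.content == b"Not Found"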
@@ -0,0 +1,50 @@
from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from models.provider import ProviderType


def get_mocked_fetch_model_config(
    provider: str,
    model: str,
    mode: str,
    credentials: dict,
):
    model_provider_factory = ModelProviderFactory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
    model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM)
    provider_model_bundle = ProviderModelBundle(
        configuration=ProviderConfiguration(
            tenant_id="1",
            provider=model_provider_factory.get_provider_schema(provider),
            preferred_provider_type=ProviderType.CUSTOM,
            using_provider_type=ProviderType.CUSTOM,
            system_configuration=SystemConfiguration(enabled=False),
            custom_configuration=CustomConfiguration(provider=CustomProviderConfiguration(credentials=credentials)),
            model_settings=[],
        ),
        model_type_instance=model_type_instance,
    )
    model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model=model)
    model_schema = model_provider_factory.get_model_schema(
        provider=provider,
        model_type=model_type_instance.model_type,
        model=model,
        credentials=credentials,
    )
    assert model_schema is not None
    model_config = ModelConfigWithCredentialsEntity(
        model=model,
        provider=provider,
        mode=mode,
        credentials=credentials,
        parameters={},
        model_schema=model_schema,
        provider_model_bundle=provider_model_bundle,
    )

    return MagicMock(return_value=(model_instance, model_config))
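Because the helper returns a MagicMock whose return value is the (model_instance, model_config) pair, tests can assign it straight over a node's _fetch_model_config, as the parameter-extractor tests in this commit do. A hedged sketch of that usage (node construction is elided and the credentials are placeholders):

node._fetch_model_config = get_mocked_fetch_model_config(
    provider="langgenius/openai/openai",
    model="gpt-3.5-turbo",
    mode="chat",
    credentials={"openai_api_key": "sk-placeholder"},
)
model_instance, model_config = node._fetch_model_config()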
@@ -0,0 +1,11 @@
import pytest

from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor

CODE_LANGUAGE = "unsupported_language"


def test_unsupported_with_code_template():
    with pytest.raises(CodeExecutionError) as e:
        CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={})
    assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}"
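An equivalent assertion can be written with pytest.raises(match=...), which applies its argument as a regular expression to str(exception); shown here only for comparison:

def test_unsupported_with_code_template_match():
    with pytest.raises(CodeExecutionError, match=f"Unsupported language {CODE_LANGUAGE}"):
        CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={})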
@@ -0,0 +1,38 @@
from textwrap import dedent

from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer

CODE_LANGUAGE = CodeLanguage.JAVASCRIPT


def test_javascript_plain():
    code = 'console.log("Hello World")'
    result_message = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code)
    assert result_message == "Hello World\n"


def test_javascript_json():
    code = dedent("""
    obj = {'Hello': 'World'}
    console.log(JSON.stringify(obj))
    """)
    result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code)
    assert result == '{"Hello":"World"}\n'


def test_javascript_with_code_template():
    result = CodeExecutor.execute_workflow_code_template(
        language=CODE_LANGUAGE,
        code=JavascriptCodeProvider.get_default_code(),
        inputs={"arg1": "Hello", "arg2": "World"},
    )
    assert result == {"result": "HelloWorld"}


def test_javascript_get_runner_script():
    runner_script = NodeJsTemplateTransformer.get_runner_script()
    assert runner_script.count(NodeJsTemplateTransformer._code_placeholder) == 1
    assert runner_script.count(NodeJsTemplateTransformer._inputs_placeholder) == 1
    assert runner_script.count(NodeJsTemplateTransformer._result_tag) == 2
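Note on the expected string: Node's JSON.stringify emits no space after the colon, so the assertion above is '{"Hello":"World"}\n', whereas Python's json.dumps defaults to one space (compare the python3 test in this set):

import json

assert json.dumps({"Hello": "World"}) == '{"Hello": "World"}'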
@@ -0,0 +1,34 @@
import base64

from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer

CODE_LANGUAGE = CodeLanguage.JINJA2


def test_jinja2():
    template = "Hello {{template}}"
    inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
    code = (
        Jinja2TemplateTransformer.get_runner_script()
        .replace(Jinja2TemplateTransformer._code_placeholder, template)
        .replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
    )
    result = CodeExecutor.execute_code(
        language=CODE_LANGUAGE, preload=Jinja2TemplateTransformer.get_preload_script(), code=code
    )
    assert result == "<<RESULT>>Hello World<<RESULT>>\n"


def test_jinja2_with_code_template():
    result = CodeExecutor.execute_workflow_code_template(
        language=CODE_LANGUAGE, code="Hello {{template}}", inputs={"template": "World"}
    )
    assert result == {"result": "Hello World"}


def test_jinja2_get_runner_script():
    runner_script = Jinja2TemplateTransformer.get_runner_script()
    assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1
    assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
    assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2
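The two occurrences of _result_tag asserted above are the opening and closing markers around the rendered payload in raw stdout (visible in the "<<RESULT>>Hello World<<RESULT>>\n" expectation). A hedged sketch of how such a delimited result can be extracted (illustrative, not the transformer's actual code):

RESULT_TAG = "<<RESULT>>"  # assumed to match Jinja2TemplateTransformer._result_tag


def extract_result(stdout: str) -> str:
    start = stdout.index(RESULT_TAG) + len(RESULT_TAG)
    end = stdout.index(RESULT_TAG, start)
    return stdout[start:end]


assert extract_result("<<RESULT>>Hello World<<RESULT>>\n") == "Hello World"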
@@ -0,0 +1,36 @@
from textwrap import dedent

from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer

CODE_LANGUAGE = CodeLanguage.PYTHON3


def test_python3_plain():
    code = 'print("Hello World")'
    result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code)
    assert result == "Hello World\n"


def test_python3_json():
    code = dedent("""
    import json
    print(json.dumps({'Hello': 'World'}))
    """)
    result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code)
    assert result == '{"Hello": "World"}\n'


def test_python3_with_code_template():
    result = CodeExecutor.execute_workflow_code_template(
        language=CODE_LANGUAGE, code=Python3CodeProvider.get_default_code(), inputs={"arg1": "Hello", "arg2": "World"}
    )
    assert result == {"result": "HelloWorld"}


def test_python3_get_runner_script():
    runner_script = Python3TemplateTransformer.get_runner_script()
    assert runner_script.count(Python3TemplateTransformer._code_placeholder) == 1
    assert runner_script.count(Python3TemplateTransformer._inputs_placeholder) == 1
    assert runner_script.count(Python3TemplateTransformer._result_tag) == 2
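The {"result": "HelloWorld"} expectation implies that the default code template concatenates arg1 and arg2. A hedged sketch of that shape (the code-node integration tests later in this commit hand-write essentially the same function, but this is not necessarily Python3CodeProvider's exact output):

def main(arg1: str, arg2: str) -> dict:
    return {
        "result": arg1 + arg2,
    }


assert main("Hello", "World") == {"result": "HelloWorld"}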
dify/api/tests/integration_tests/workflow/nodes/test_code.py (new file, 397 lines added)
@@ -0,0 +1,397 @@
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
|
||||
|
||||
CODE_MAX_STRING_LENGTH = dify_config.CODE_MAX_STRING_LENGTH
|
||||
|
||||
|
||||
def init_code_node(code_config: dict):
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-code-target",
|
||||
"source": "start",
|
||||
"target": "code",
|
||||
},
|
||||
],
|
||||
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, code_config],
|
||||
}
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(["code", "args1"], 1)
|
||||
variable_pool.add(["code", "args2"], 2)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# Create node factory
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node = CodeNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=code_config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
if "data" in code_config:
|
||||
node.init_node_data(code_config["data"])
|
||||
|
||||
return node
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
|
||||
def test_execute_code(setup_code_executor_mock):
|
||||
code = """
|
||||
def main(args1: int, args2: int):
|
||||
return {
|
||||
"result": args1 + args2,
|
||||
}
|
||||
"""
|
||||
# trim first 4 spaces at the beginning of each line
|
||||
code = "\n".join([line[4:] for line in code.split("\n")])
|
||||
|
||||
code_config = {
|
||||
"id": "code",
|
||||
"data": {
|
||||
"type": "code",
|
||||
"outputs": {
|
||||
"result": {
|
||||
"type": "number",
|
||||
},
|
||||
},
|
||||
"title": "123",
|
||||
"variables": [
|
||||
{
|
||||
"variable": "args1",
|
||||
"value_selector": ["1", "args1"],
|
||||
},
|
||||
{"variable": "args2", "value_selector": ["1", "args2"]},
|
||||
],
|
||||
"answer": "123",
|
||||
"code_language": "python3",
|
||||
"code": code,
|
||||
},
|
||||
}
|
||||
|
||||
node = init_code_node(code_config)
|
||||
node.graph_runtime_state.variable_pool.add(["1", "args1"], 1)
|
||||
node.graph_runtime_state.variable_pool.add(["1", "args2"], 2)
|
||||
|
||||
# execute node
|
||||
result = node._run()
|
||||
assert isinstance(result, NodeRunResult)
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["result"] == 3
|
||||
assert result.error == ""
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
|
||||
def test_execute_code_output_validator(setup_code_executor_mock):
|
||||
code = """
|
||||
def main(args1: int, args2: int):
|
||||
return {
|
||||
"result": args1 + args2,
|
||||
}
|
||||
"""
|
||||
# trim first 4 spaces at the beginning of each line
|
||||
code = "\n".join([line[4:] for line in code.split("\n")])
|
||||
|
||||
code_config = {
|
||||
"id": "code",
|
||||
"data": {
|
||||
"type": "code",
|
||||
"outputs": {
|
||||
"result": {
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"title": "123",
|
||||
"variables": [
|
||||
{
|
||||
"variable": "args1",
|
||||
"value_selector": ["1", "args1"],
|
||||
},
|
||||
{"variable": "args2", "value_selector": ["1", "args2"]},
|
||||
],
|
||||
"answer": "123",
|
||||
"code_language": "python3",
|
||||
"code": code,
|
||||
},
|
||||
}
|
||||
|
||||
node = init_code_node(code_config)
|
||||
node.graph_runtime_state.variable_pool.add(["1", "args1"], 1)
|
||||
node.graph_runtime_state.variable_pool.add(["1", "args2"], 2)
|
||||
|
||||
# execute node
|
||||
result = node._run()
|
||||
assert isinstance(result, NodeRunResult)
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert result.error == "Output result must be a string, got int instead"
|
||||
|
||||
|
||||
def test_execute_code_output_validator_depth():
|
||||
code = """
|
||||
def main(args1: int, args2: int):
|
||||
return {
|
||||
"result": {
|
||||
"result": args1 + args2,
|
||||
}
|
||||
}
|
||||
"""
|
||||
# trim first 4 spaces at the beginning of each line
|
||||
code = "\n".join([line[4:] for line in code.split("\n")])
|
||||
|
||||
code_config = {
|
||||
"id": "code",
|
||||
"data": {
|
||||
"type": "code",
|
||||
"outputs": {
|
||||
"string_validator": {
|
||||
"type": "string",
|
||||
},
|
||||
"number_validator": {
|
||||
"type": "number",
|
||||
},
|
||||
"number_array_validator": {
|
||||
"type": "array[number]",
|
||||
},
|
||||
"string_array_validator": {
|
||||
"type": "array[string]",
|
||||
},
|
||||
"object_validator": {
|
||||
"type": "object",
|
||||
"children": {
|
||||
"result": {
|
||||
"type": "number",
|
||||
},
|
||||
"depth": {
|
||||
"type": "object",
|
||||
"children": {
|
||||
"depth": {
|
||||
"type": "object",
|
||||
"children": {
|
||||
"depth": {
|
||||
"type": "number",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"title": "123",
|
||||
"variables": [
|
||||
{
|
||||
"variable": "args1",
|
||||
"value_selector": ["1", "args1"],
|
||||
},
|
||||
{"variable": "args2", "value_selector": ["1", "args2"]},
|
||||
],
|
||||
"answer": "123",
|
||||
"code_language": "python3",
|
||||
"code": code,
|
||||
},
|
||||
}
|
||||
|
||||
node = init_code_node(code_config)
|
||||
|
||||
# construct result
|
||||
result = {
|
||||
"number_validator": 1,
|
||||
"string_validator": "1",
|
||||
"number_array_validator": [1, 2, 3, 3.333],
|
||||
"string_array_validator": ["1", "2", "3"],
|
||||
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
|
||||
}
|
||||
|
||||
# validate
|
||||
node._transform_result(result, node._node_data.outputs)
|
||||
|
||||
# construct result
|
||||
result = {
|
||||
"number_validator": "1",
|
||||
"string_validator": 1,
|
||||
"number_array_validator": ["1", "2", "3", "3.333"],
|
||||
"string_array_validator": [1, 2, 3],
|
||||
"object_validator": {"result": "1", "depth": {"depth": {"depth": "1"}}},
|
||||
}
|
||||
|
||||
# validate
|
||||
with pytest.raises(ValueError):
|
||||
node._transform_result(result, node._node_data.outputs)
|
||||
|
||||
# construct result
|
||||
result = {
|
||||
"number_validator": 1,
|
||||
"string_validator": (CODE_MAX_STRING_LENGTH + 1) * "1",
|
||||
"number_array_validator": [1, 2, 3, 3.333],
|
||||
"string_array_validator": ["1", "2", "3"],
|
||||
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
|
||||
}
|
||||
|
||||
# validate
|
||||
with pytest.raises(ValueError):
|
||||
node._transform_result(result, node._node_data.outputs)
|
||||
|
||||
# construct result
|
||||
result = {
|
||||
"number_validator": 1,
|
||||
"string_validator": "1",
|
||||
"number_array_validator": [1, 2, 3, 3.333] * 2000,
|
||||
"string_array_validator": ["1", "2", "3"],
|
||||
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
|
||||
}
|
||||
|
||||
# validate
|
||||
with pytest.raises(ValueError):
|
||||
node._transform_result(result, node._node_data.outputs)
|
||||
|
||||
|
||||
def test_execute_code_output_object_list():
|
||||
code = """
|
||||
def main(args1: int, args2: int):
|
||||
return {
|
||||
"result": {
|
||||
"result": args1 + args2,
|
||||
}
|
||||
}
|
||||
"""
|
||||
# trim first 4 spaces at the beginning of each line
|
||||
code = "\n".join([line[4:] for line in code.split("\n")])
|
||||
|
||||
code_config = {
|
||||
"id": "code",
|
||||
"data": {
|
||||
"type": "code",
|
||||
"outputs": {
|
||||
"object_list": {
|
||||
"type": "array[object]",
|
||||
},
|
||||
},
|
||||
"title": "123",
|
||||
"variables": [
|
||||
{
|
||||
"variable": "args1",
|
||||
"value_selector": ["1", "args1"],
|
||||
},
|
||||
{"variable": "args2", "value_selector": ["1", "args2"]},
|
||||
],
|
||||
"answer": "123",
|
||||
"code_language": "python3",
|
||||
"code": code,
|
||||
},
|
||||
}
|
||||
|
||||
node = init_code_node(code_config)
|
||||
|
||||
# construct result
|
||||
result = {
|
||||
"object_list": [
|
||||
{
|
||||
"result": 1,
|
||||
},
|
||||
{
|
||||
"result": 2,
|
||||
},
|
||||
{
|
||||
"result": [1, 2, 3],
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
# validate
|
||||
node._transform_result(result, node._node_data.outputs)
|
||||
|
||||
# construct result
|
||||
result = {
|
||||
"object_list": [
|
||||
{
|
||||
"result": 1,
|
||||
},
|
||||
{
|
||||
"result": 2,
|
||||
},
|
||||
{
|
||||
"result": [1, 2, 3],
|
||||
},
|
||||
1,
|
||||
]
|
||||
}
|
||||
|
||||
# validate
|
||||
with pytest.raises(ValueError):
|
||||
node._transform_result(result, node._node_data.outputs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
|
||||
def test_execute_code_scientific_notation(setup_code_executor_mock):
|
||||
code = """
|
||||
def main():
|
||||
return {
|
||||
"result": -8.0E-5
|
||||
}
|
||||
"""
|
||||
code = "\n".join([line[4:] for line in code.split("\n")])
|
||||
|
||||
code_config = {
|
||||
"id": "code",
|
||||
"data": {
|
||||
"type": "code",
|
||||
"outputs": {
|
||||
"result": {
|
||||
"type": "number",
|
||||
},
|
||||
},
|
||||
"title": "123",
|
||||
"variables": [],
|
||||
"answer": "123",
|
||||
"code_language": "python3",
|
||||
"code": code,
|
||||
},
|
||||
}
|
||||
|
||||
node = init_code_node(code_config)
|
||||
# execute node
|
||||
result = node._run()
|
||||
assert isinstance(result, NodeRunResult)
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
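The tests above repeatedly strip a fixed four-space indent from triple-quoted code with "\n".join(line[4:] for line in code.split("\n")); textwrap.dedent is the library equivalent when the indentation is uniform, e.g.:

from textwrap import dedent

code = dedent("""
    def main(args1: int, args2: int):
        return {
            "result": args1 + args2,
        }
    """)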
dify/api/tests/integration_tests/workflow/nodes/test_http.py (new file, 723 lines added)
@@ -0,0 +1,723 @@
|
||||
import time
|
||||
import uuid
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.http_request.node import HttpRequestNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock
|
||||
|
||||
|
||||
def init_http_node(config: dict):
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-next-target",
|
||||
"source": "start",
|
||||
"target": "1",
|
||||
},
|
||||
],
|
||||
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
|
||||
}
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(["a", "args1"], 1)
|
||||
variable_pool.add(["a", "args2"], 2)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# Create node factory
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node = HttpRequestNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
if "data" in config:
|
||||
node.init_node_data(config["data"])
|
||||
|
||||
return node
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_get(setup_http_mock):
|
||||
node = init_http_node(
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
"url": "http://example.com",
|
||||
"authorization": {
|
||||
"type": "api-key",
|
||||
"config": {
|
||||
"type": "basic",
|
||||
"api_key": "ak-xxx",
|
||||
"header": "api-key",
|
||||
},
|
||||
},
|
||||
"headers": "X-Header:123",
|
||||
"params": "A:b",
|
||||
"body": None,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
assert result.process_data is not None
|
||||
data = result.process_data.get("request", "")
|
||||
|
||||
assert "?A=b" in data
|
||||
assert "X-Header: 123" in data
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_no_auth(setup_http_mock):
|
||||
node = init_http_node(
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
"url": "http://example.com",
|
||||
"authorization": {
|
||||
"type": "no-auth",
|
||||
"config": None,
|
||||
},
|
||||
"headers": "X-Header:123",
|
||||
"params": "A:b",
|
||||
"body": None,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
assert result.process_data is not None
|
||||
data = result.process_data.get("request", "")
|
||||
|
||||
assert "?A=b" in data
|
||||
assert "X-Header: 123" in data
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_custom_authorization_header(setup_http_mock):
|
||||
node = init_http_node(
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
"url": "http://example.com",
|
||||
"authorization": {
|
||||
"type": "api-key",
|
||||
"config": {
|
||||
"type": "custom",
|
||||
"api_key": "Auth",
|
||||
"header": "X-Auth",
|
||||
},
|
||||
},
|
||||
"headers": "X-Header:123",
|
||||
"params": "A:b",
|
||||
"body": None,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
assert result.process_data is not None
|
||||
data = result.process_data.get("request", "")
|
||||
|
||||
assert "?A=b" in data
|
||||
assert "X-Header: 123" in data
|
||||
# Custom authorization header should be set (may be masked)
|
||||
assert "X-Auth:" in data
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock):
|
||||
"""Test: In custom authentication mode, when the api_key is empty, no header should be set."""
|
||||
from core.workflow.nodes.http_request.entities import (
|
||||
HttpRequestNodeAuthorization,
|
||||
HttpRequestNodeData,
|
||||
HttpRequestNodeTimeout,
|
||||
)
|
||||
from core.workflow.nodes.http_request.executor import Executor
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
# Create variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="test", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
# Create node data with custom auth and empty api_key
|
||||
node_data = HttpRequestNodeData(
|
||||
title="http",
|
||||
desc="",
|
||||
url="http://example.com",
|
||||
method="get",
|
||||
authorization=HttpRequestNodeAuthorization(
|
||||
type="api-key",
|
||||
config={
|
||||
"type": "custom",
|
||||
"api_key": "", # Empty api_key
|
||||
"header": "X-Custom-Auth",
|
||||
},
|
||||
),
|
||||
headers="",
|
||||
params="",
|
||||
body=None,
|
||||
ssl_verify=True,
|
||||
)
|
||||
|
||||
# Create executor
|
||||
executor = Executor(
|
||||
node_data=node_data, timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10), variable_pool=variable_pool
|
||||
)
|
||||
|
||||
# Get assembled headers
|
||||
headers = executor._assembling_headers()
|
||||
|
||||
# When api_key is empty, the custom header should NOT be set
|
||||
assert "X-Custom-Auth" not in headers
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_bearer_authorization_with_custom_header_ignored(setup_http_mock):
|
||||
"""
|
||||
Test that when switching from custom to bearer authorization,
|
||||
the custom header settings don't interfere with bearer token.
|
||||
This test verifies the fix for issue #23554.
|
||||
"""
|
||||
node = init_http_node(
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
"url": "http://example.com",
|
||||
"authorization": {
|
||||
"type": "api-key",
|
||||
"config": {
|
||||
"type": "bearer",
|
||||
"api_key": "test-token",
|
||||
"header": "", # Empty header - should default to Authorization
|
||||
},
|
||||
},
|
||||
"headers": "",
|
||||
"params": "",
|
||||
"body": None,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
assert result.process_data is not None
|
||||
data = result.process_data.get("request", "")
|
||||
|
||||
# In bearer mode, should use Authorization header (value is masked with *)
|
||||
assert "Authorization: " in data
|
||||
# Should contain masked Bearer token
|
||||
assert "*" in data
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_basic_authorization_with_custom_header_ignored(setup_http_mock):
|
||||
"""
|
||||
Test that when switching from custom to basic authorization,
|
||||
the custom header settings don't interfere with basic auth.
|
||||
This test verifies the fix for issue #23554.
|
||||
"""
|
||||
node = init_http_node(
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
"url": "http://example.com",
|
||||
"authorization": {
|
||||
"type": "api-key",
|
||||
"config": {
|
||||
"type": "basic",
|
||||
"api_key": "user:pass",
|
||||
"header": "", # Empty header - should default to Authorization
|
||||
},
|
||||
},
|
||||
"headers": "",
|
||||
"params": "",
|
||||
"body": None,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
assert result.process_data is not None
|
||||
data = result.process_data.get("request", "")
|
||||
|
||||
# In basic mode, should use Authorization header (value is masked with *)
|
||||
assert "Authorization: " in data
|
||||
# Should contain masked Basic credentials
|
||||
assert "*" in data
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_custom_authorization_with_empty_api_key(setup_http_mock):
|
||||
"""
|
||||
Test that custom authorization doesn't set header when api_key is empty.
|
||||
This test verifies the fix for issue #23554.
|
||||
"""
|
||||
node = init_http_node(
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
"url": "http://example.com",
|
||||
"authorization": {
|
||||
"type": "api-key",
|
||||
"config": {
|
||||
"type": "custom",
|
||||
"api_key": "", # Empty api_key
|
||||
"header": "X-Custom-Auth",
|
||||
},
|
||||
},
|
||||
"headers": "",
|
||||
"params": "",
|
||||
"body": None,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
assert result.process_data is not None
|
||||
data = result.process_data.get("request", "")
|
||||
|
||||
# Custom header should NOT be set when api_key is empty
|
||||
assert "X-Custom-Auth:" not in data
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_template(setup_http_mock):
|
||||
node = init_http_node(
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
"url": "http://example.com/{{#a.args2#}}",
|
||||
"authorization": {
|
||||
"type": "api-key",
|
||||
"config": {
|
||||
"type": "basic",
|
||||
"api_key": "ak-xxx",
|
||||
"header": "api-key",
|
||||
},
|
||||
},
|
||||
"headers": "X-Header:123\nX-Header2:{{#a.args2#}}",
|
||||
"params": "A:b\nTemplate:{{#a.args2#}}",
|
||||
"body": None,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
assert result.process_data is not None
|
||||
data = result.process_data.get("request", "")
|
||||
|
||||
assert "?A=b" in data
|
||||
assert "Template=2" in data
|
||||
assert "X-Header: 123" in data
|
||||
assert "X-Header2: 2" in data
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_json(setup_http_mock):
|
||||
node = init_http_node(
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "post",
|
||||
"url": "http://example.com",
|
||||
"authorization": {
|
||||
"type": "api-key",
|
||||
"config": {
|
||||
"type": "basic",
|
||||
"api_key": "ak-xxx",
|
||||
"header": "api-key",
|
||||
},
|
||||
},
|
||||
"headers": "X-Header:123",
|
||||
"params": "A:b",
|
||||
"body": {
|
||||
"type": "json",
|
||||
"data": [
|
||||
{
|
||||
"key": "",
|
||||
"type": "text",
|
||||
"value": '{"a": "{{#a.args1#}}"}',
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
assert result.process_data is not None
|
||||
data = result.process_data.get("request", "")
|
||||
|
||||
assert '{"a": "1"}' in data
|
||||
assert "X-Header: 123" in data
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_x_www_form_urlencoded(setup_http_mock):
|
||||
node = init_http_node(
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "post",
|
||||
"url": "http://example.com",
|
||||
"authorization": {
|
||||
"type": "api-key",
|
||||
"config": {
|
||||
"type": "basic",
|
||||
"api_key": "ak-xxx",
|
||||
"header": "api-key",
|
||||
},
|
||||
},
|
||||
"headers": "X-Header:123",
|
||||
"params": "A:b",
|
||||
"body": {
|
||||
"type": "x-www-form-urlencoded",
|
||||
"data": [
|
||||
{
|
||||
"key": "a",
|
||||
"type": "text",
|
||||
"value": "{{#a.args1#}}",
|
||||
},
|
||||
{
|
||||
"key": "b",
|
||||
"type": "text",
|
||||
"value": "{{#a.args2#}}",
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
assert result.process_data is not None
|
||||
data = result.process_data.get("request", "")
|
||||
|
||||
assert "a=1&b=2" in data
|
||||
assert "X-Header: 123" in data
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_form_data(setup_http_mock):
|
||||
node = init_http_node(
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "post",
|
||||
"url": "http://example.com",
|
||||
"authorization": {
|
||||
"type": "api-key",
|
||||
"config": {
|
||||
"type": "basic",
|
||||
"api_key": "ak-xxx",
|
||||
"header": "api-key",
|
||||
},
|
||||
},
|
||||
"headers": "X-Header:123",
|
||||
"params": "A:b",
|
||||
"body": {
|
||||
"type": "form-data",
|
||||
"data": [
|
||||
{
|
||||
"key": "a",
|
||||
"type": "text",
|
||||
"value": "{{#a.args1#}}",
|
||||
},
|
||||
{
|
||||
"key": "b",
|
||||
"type": "text",
|
||||
"value": "{{#a.args2#}}",
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
assert result.process_data is not None
|
||||
data = result.process_data.get("request", "")
|
||||
|
||||
assert 'form-data; name="a"' in data
|
||||
assert "1" in data
|
||||
assert 'form-data; name="b"' in data
|
||||
assert "2" in data
|
||||
assert "X-Header: 123" in data
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_none_data(setup_http_mock):
|
||||
node = init_http_node(
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "post",
|
||||
"url": "http://example.com",
|
||||
"authorization": {
|
||||
"type": "api-key",
|
||||
"config": {
|
||||
"type": "basic",
|
||||
"api_key": "ak-xxx",
|
||||
"header": "api-key",
|
||||
},
|
||||
},
|
||||
"headers": "X-Header:123",
|
||||
"params": "A:b",
|
||||
"body": {"type": "none", "data": []},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
assert result.process_data is not None
|
||||
data = result.process_data.get("request", "")
|
||||
|
||||
assert "X-Header: 123" in data
|
||||
assert "123123123" not in data
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_mock_404(setup_http_mock):
|
||||
node = init_http_node(
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
"url": "http://404.com",
|
||||
"authorization": {
|
||||
"type": "no-auth",
|
||||
"config": None,
|
||||
},
|
||||
"body": None,
|
||||
"params": "",
|
||||
"headers": "X-Header:123",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
assert result.outputs is not None
|
||||
resp = result.outputs
|
||||
|
||||
assert resp.get("status_code") == 404
|
||||
assert "Not Found" in resp.get("body", "")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_multi_colons_parse(setup_http_mock):
|
||||
node = init_http_node(
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
"url": "http://example.com",
|
||||
"authorization": {
|
||||
"type": "no-auth",
|
||||
"config": None,
|
||||
},
|
||||
"params": "Referer:http://example1.com\nRedirect:http://example2.com",
|
||||
"headers": "Referer:http://example3.com\nRedirect:http://example4.com",
|
||||
"body": {
|
||||
"type": "form-data",
|
||||
"data": [
|
||||
{
|
||||
"key": "Referer",
|
||||
"type": "text",
|
||||
"value": "http://example5.com",
|
||||
},
|
||||
{
|
||||
"key": "Redirect",
|
||||
"type": "text",
|
||||
"value": "http://example6.com",
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
assert result.process_data is not None
|
||||
assert result.outputs is not None
|
||||
|
||||
assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request", "")
|
||||
assert 'form-data; name="Redirect"\r\n\r\nhttp://example6.com' in result.process_data.get("request", "")
|
||||
# resp = result.outputs
|
||||
# assert "http://example3.com" == resp.get("headers", {}).get("referer")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_nested_object_variable_selector(setup_http_mock):
|
||||
"""Test variable selector functionality with nested object properties."""
|
||||
# Create independent test setup without affecting other tests
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-next-target",
|
||||
"source": "start",
|
||||
"target": "1",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start", "title": "Start"}, "id": "start"},
|
||||
{
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
"url": "http://example.com/{{#a.args2#}}/{{#a.args3.nested#}}",
|
||||
"authorization": {
|
||||
"type": "api-key",
|
||||
"config": {
|
||||
"type": "basic",
|
||||
"api_key": "ak-xxx",
|
||||
"header": "api-key",
|
||||
},
|
||||
},
|
||||
"headers": "X-Header:{{#a.args3.nested#}}",
|
||||
"params": "nested_param:{{#a.args3.nested#}}",
|
||||
"body": None,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# Create independent variable pool for this test only
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(["a", "args1"], 1)
|
||||
variable_pool.add(["a", "args2"], 2)
|
||||
variable_pool.add(["a", "args3"], {"nested": "nested_value"}) # Only for this test
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# Create node factory
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node = HttpRequestNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=graph_config["nodes"][1],
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
if "data" in graph_config["nodes"][1]:
|
||||
node.init_node_data(graph_config["nodes"][1]["data"])
|
||||
|
||||
result = node._run()
|
||||
assert result.process_data is not None
|
||||
data = result.process_data.get("request", "")
|
||||
|
||||
# Verify nested object property is correctly resolved
|
||||
assert "/2/nested_value" in data # URL path should contain resolved nested value
|
||||
assert "X-Header: nested_value" in data # Header should contain nested value
|
||||
assert "nested_param=nested_value" in data # Param should contain nested value
|
||||
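test_multi_colons_parse above pins down that header and param lines are split on the first colon only, so URL values such as Referer:http://example1.com survive intact. A hedged, illustrative parser with that behaviour (not the HTTP-request node's actual implementation):

def parse_kv_lines(text: str) -> dict[str, str]:
    parsed: dict[str, str] = {}
    for line in text.splitlines():
        if not line.strip():
            continue
        key, _, value = line.partition(":")  # split at the first colon only
        parsed[key.strip()] = value.strip()
    return parsed


assert parse_kv_lines("Referer:http://example1.com\nRedirect:http://example2.com") == {
    "Referer": "http://example1.com",
    "Redirect": "http://example2.com",
}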
dify/api/tests/integration_tests/workflow/nodes/test_llm.py (new file, 308 lines added)
@@ -0,0 +1,308 @@
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.llm_generator.output_parser.structured_output import _parse_structured_output
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.node_events import StreamCompletedEvent
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
|
||||
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
|
||||
|
||||
|
||||
def init_llm_node(config: dict) -> LLMNode:
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-next-target",
|
||||
"source": "start",
|
||||
"target": "llm",
|
||||
},
|
||||
],
|
||||
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
|
||||
}
|
||||
|
||||
# Use proper UUIDs for database compatibility
|
||||
tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
|
||||
app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c"
|
||||
workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d"
|
||||
user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e"
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
graph_config=graph_config,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="aaa",
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
files=[],
|
||||
query="what's the weather today?",
|
||||
conversation_id="abababa",
|
||||
),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(["abc", "output"], "sunny")
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# Create node factory
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node = LLMNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
if "data" in config:
|
||||
node.init_node_data(config["data"])
|
||||
|
||||
return node
|
||||
|
||||
|
||||
def test_execute_llm():
|
||||
node = init_llm_node(
|
||||
config={
|
||||
"id": "llm",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"type": "llm",
|
||||
"model": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {},
|
||||
},
|
||||
"prompt_template": [
|
||||
{
|
||||
"role": "system",
|
||||
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.",
|
||||
},
|
||||
{"role": "user", "text": "{{#sys.query#}}"},
|
||||
],
|
||||
"memory": None,
|
||||
"context": {"enabled": False},
|
||||
"vision": {"enabled": False},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# Mock the _fetch_model_config to avoid database calls
|
||||
def mock_fetch_model_config(**_kwargs):
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
|
||||
# Create mock model instance
|
||||
mock_model_instance = MagicMock()
|
||||
mock_usage = LLMUsage(
|
||||
prompt_tokens=30,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal(1000),
|
||||
prompt_price=Decimal("0.00003"),
|
||||
completion_tokens=20,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal(1000),
|
||||
completion_price=Decimal("0.00004"),
|
||||
total_tokens=50,
|
||||
total_price=Decimal("0.00007"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
mock_message = AssistantPromptMessage(content="Test response from mock")
|
||||
mock_llm_result = LLMResult(
|
||||
model="gpt-3.5-turbo",
|
||||
prompt_messages=[],
|
||||
message=mock_message,
|
||||
usage=mock_usage,
|
||||
)
|
||||
mock_model_instance.invoke_llm.return_value = mock_llm_result
|
||||
|
||||
# Create mock model config
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.mode = "chat"
|
||||
mock_model_config.provider = "openai"
|
||||
mock_model_config.model = "gpt-3.5-turbo"
|
||||
mock_model_config.parameters = {}
|
||||
|
||||
return mock_model_instance, mock_model_config
|
||||
|
||||
# Mock fetch_prompt_messages to avoid database calls
|
||||
def mock_fetch_prompt_messages_1(**_kwargs):
|
||||
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
|
||||
|
||||
return [
|
||||
SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
|
||||
UserPromptMessage(content="what's the weather today?"),
|
||||
], []
|
||||
|
||||
with (
|
||||
patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
|
||||
patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1),
|
||||
):
|
||||
# execute node
|
||||
result = node._run()
|
||||
assert isinstance(result, Generator)
|
||||
|
||||
for item in result:
|
||||
if isinstance(item, StreamCompletedEvent):
|
||||
if item.node_run_result.status != WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
print(f"Error: {item.node_run_result.error}")
|
||||
print(f"Error type: {item.node_run_result.error_type}")
|
||||
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.node_run_result.process_data is not None
|
||||
assert item.node_run_result.outputs is not None
|
||||
assert item.node_run_result.outputs.get("text") is not None
|
||||
assert item.node_run_result.outputs.get("usage", {})["total_tokens"] > 0
|
||||
|
||||
|
||||
def test_execute_llm_with_jinja2():
|
||||
"""
|
||||
Test execute LLM node with jinja2
|
||||
"""
|
||||
node = init_llm_node(
|
||||
config={
|
||||
"id": "llm",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"type": "llm",
|
||||
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
|
||||
"prompt_config": {
|
||||
"jinja2_variables": [
|
||||
{"variable": "sys_query", "value_selector": ["sys", "query"]},
|
||||
{"variable": "output", "value_selector": ["abc", "output"]},
|
||||
]
|
||||
},
|
||||
"prompt_template": [
|
||||
{
|
||||
"role": "system",
|
||||
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
|
||||
"jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
|
||||
"edition_type": "jinja2",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"text": "{{#sys.query#}}",
|
||||
"jinja2_text": "{{sys_query}}",
|
||||
"edition_type": "basic",
|
||||
},
|
||||
],
|
||||
"memory": None,
|
||||
"context": {"enabled": False},
|
||||
"vision": {"enabled": False},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# Mock the _fetch_model_config method
|
||||
def mock_fetch_model_config(**_kwargs):
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
|
||||
# Create mock model instance
|
||||
mock_model_instance = MagicMock()
|
||||
mock_usage = LLMUsage(
|
||||
prompt_tokens=30,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal(1000),
|
||||
prompt_price=Decimal("0.00003"),
|
||||
completion_tokens=20,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal(1000),
|
||||
completion_price=Decimal("0.00004"),
|
||||
total_tokens=50,
|
||||
total_price=Decimal("0.00007"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
|
||||
mock_llm_result = LLMResult(
|
||||
model="gpt-3.5-turbo",
|
||||
prompt_messages=[],
|
||||
message=mock_message,
|
||||
usage=mock_usage,
|
||||
)
|
||||
mock_model_instance.invoke_llm.return_value = mock_llm_result
|
||||
|
||||
# Create mock model config
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.mode = "chat"
|
||||
mock_model_config.provider = "openai"
|
||||
mock_model_config.model = "gpt-3.5-turbo"
|
||||
mock_model_config.parameters = {}
|
||||
|
||||
return mock_model_instance, mock_model_config
|
||||
|
||||
# Mock fetch_prompt_messages to avoid database calls
|
||||
def mock_fetch_prompt_messages_2(**_kwargs):
|
||||
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
|
||||
|
||||
return [
|
||||
SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
|
||||
UserPromptMessage(content="what's the weather today?"),
|
||||
], []
|
||||
|
||||
with (
|
||||
patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
|
||||
patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2),
|
||||
):
|
||||
# execute node
|
||||
result = node._run()
|
||||
|
||||
for item in result:
|
||||
if isinstance(item, StreamCompletedEvent):
|
||||
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.node_run_result.process_data is not None
|
||||
assert "sunny" in json.dumps(item.node_run_result.process_data)
|
||||
assert "what's the weather today?" in json.dumps(item.node_run_result.process_data)
|
||||
|
||||
|
||||
def test_extract_json():
|
||||
llm_texts = [
|
||||
'<think>\n\n</think>{"name": "test", "age": 123',  # reasoning model (deepseek-r1)
|
||||
'{"name":"test","age":123}', # json schema model (gpt-4o)
|
||||
'{\n "name": "test",\n "age": 123\n}', # small model (llama-3.2-1b)
|
||||
'```json\n{"name": "test", "age": 123}\n```', # json markdown (deepseek-chat)
|
||||
'{"name":"test",age:123}', # without quotes (qwen-2.5-0.5b)
|
||||
]
|
||||
result = {"name": "test", "age": 123}
|
||||
assert all(_parse_structured_output(item) == result for item in llm_texts)
|
||||
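A quick consistency check of the mocked LLMUsage figures used in both tests above, assuming the convention price = tokens * unit_price / price_unit:

from decimal import Decimal

prompt_price = 30 * Decimal("0.001") / Decimal(1000)      # == Decimal("0.00003")
completion_price = 20 * Decimal("0.002") / Decimal(1000)  # == Decimal("0.00004")
assert prompt_price + completion_price == Decimal("0.00007")
assert 30 + 20 == 50  # total_tokens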
@@ -0,0 +1,418 @@
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.model_runtime.entities import AssistantPromptMessage
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
|
||||
|
||||
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
|
||||
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
|
||||
|
||||
|
||||
def get_mocked_fetch_memory(memory_text: str):
|
||||
class MemoryMock:
|
||||
def get_history_prompt_text(
|
||||
self,
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "Assistant",
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
):
|
||||
return memory_text
|
||||
|
||||
return MagicMock(return_value=MemoryMock())
|
||||
|
||||
|
||||
def init_parameter_extractor_node(config: dict):
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-next-target",
|
||||
"source": "start",
|
||||
"target": "llm",
|
||||
},
|
||||
],
|
||||
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
|
||||
}
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="aaa", files=[], query="what's the weather in SF", conversation_id="abababa"
|
||||
),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(["a", "args1"], 1)
|
||||
variable_pool.add(["a", "args2"], 2)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# Create node factory
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node = ParameterExtractorNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
node.init_node_data(config.get("data", {}))
|
||||
return node
|
||||
|
||||
|
||||
def test_function_calling_parameter_extractor(setup_model_mock):
    """
    Test function calling for parameter extractor.
    """
    node = init_parameter_extractor_node(
        config={
            "id": "llm",
            "data": {
                "title": "123",
                "type": "parameter-extractor",
                "model": {
                    "provider": "langgenius/openai/openai",
                    "name": "gpt-3.5-turbo",
                    "mode": "chat",
                    "completion_params": {},
                },
                "query": ["sys", "query"],
                "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}],
                "instruction": "",
                "reasoning_mode": "function_call",
                "memory": None,
            },
        }
    )

    node._fetch_model_config = get_mocked_fetch_model_config(
        provider="langgenius/openai/openai",
        model="gpt-3.5-turbo",
        mode="chat",
        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
    )
    db.session.close = MagicMock()

    result = node._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs is not None
    assert result.outputs.get("location") == "kawaii"
    assert result.outputs.get("__reason") is None

def test_instructions(setup_model_mock):
    """
    Test the parameter extractor with an instruction template.
    """
    node = init_parameter_extractor_node(
        config={
            "id": "llm",
            "data": {
                "title": "123",
                "type": "parameter-extractor",
                "model": {
                    "provider": "langgenius/openai/openai",
                    "name": "gpt-3.5-turbo",
                    "mode": "chat",
                    "completion_params": {},
                },
                "query": ["sys", "query"],
                "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}],
                "reasoning_mode": "function_call",
                "instruction": "{{#sys.query#}}",
                "memory": None,
            },
        },
    )

    node._fetch_model_config = get_mocked_fetch_model_config(
        provider="langgenius/openai/openai",
        model="gpt-3.5-turbo",
        mode="chat",
        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
    )
    db.session.close = MagicMock()

    result = node._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs is not None
    assert result.outputs.get("location") == "kawaii"
    assert result.outputs.get("__reason") is None

    process_data = result.process_data

    assert process_data is not None

    for prompt in process_data.get("prompts", []):
        if prompt.get("role") == "system":
            assert "what's the weather in SF" in prompt.get("text")

def test_chat_parameter_extractor(setup_model_mock):
    """
    Test chat parameter extractor.
    """
    node = init_parameter_extractor_node(
        config={
            "id": "llm",
            "data": {
                "title": "123",
                "type": "parameter-extractor",
                "model": {
                    "provider": "langgenius/openai/openai",
                    "name": "gpt-3.5-turbo",
                    "mode": "chat",
                    "completion_params": {},
                },
                "query": ["sys", "query"],
                "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}],
                "reasoning_mode": "prompt",
                "instruction": "",
                "memory": None,
            },
        },
    )

    node._fetch_model_config = get_mocked_fetch_model_config(
        provider="langgenius/openai/openai",
        model="gpt-3.5-turbo",
        mode="chat",
        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
    )
    db.session.close = MagicMock()

    result = node._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs is not None
    assert result.outputs.get("location") == ""
    assert (
        result.outputs.get("__reason")
        == "Failed to extract result from function call or text response, using empty result."
    )
    assert result.process_data is not None
    prompts = result.process_data.get("prompts", [])

    for prompt in prompts:
        if prompt.get("role") == "user":
            if "<structure>" in prompt.get("text"):
                assert '<structure>\n{"type": "object"' in prompt.get("text")

def test_completion_parameter_extractor(setup_model_mock):
    """
    Test completion parameter extractor.
    """
    node = init_parameter_extractor_node(
        config={
            "id": "llm",
            "data": {
                "title": "123",
                "type": "parameter-extractor",
                "model": {
                    "provider": "langgenius/openai/openai",
                    "name": "gpt-3.5-turbo-instruct",
                    "mode": "completion",
                    "completion_params": {},
                },
                "query": ["sys", "query"],
                "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}],
                "reasoning_mode": "prompt",
                "instruction": "{{#sys.query#}}",
                "memory": None,
            },
        },
    )

    node._fetch_model_config = get_mocked_fetch_model_config(
        provider="langgenius/openai/openai",
        model="gpt-3.5-turbo-instruct",
        mode="completion",
        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
    )
    db.session.close = MagicMock()

    result = node._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs is not None
    assert result.outputs.get("location") == ""
    assert (
        result.outputs.get("__reason")
        == "Failed to extract result from function call or text response, using empty result."
    )
    assert result.process_data is not None
    assert len(result.process_data.get("prompts", [])) == 1
    assert "SF" in result.process_data.get("prompts", [])[0].get("text")

def test_extract_json_response():
    """
    Test extract json response.
    """

    node = init_parameter_extractor_node(
        config={
            "id": "llm",
            "data": {
                "title": "123",
                "type": "parameter-extractor",
                "model": {
                    "provider": "langgenius/openai/openai",
                    "name": "gpt-3.5-turbo-instruct",
                    "mode": "completion",
                    "completion_params": {},
                },
                "query": ["sys", "query"],
                "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}],
                "reasoning_mode": "prompt",
                "instruction": "{{#sys.query#}}",
                "memory": None,
            },
        },
    )

    result = node._extract_complete_json_response("""
    uwu{ovo}
    {
        "location": "kawaii"
    }
    hello world.
    """)

    assert result is not None
    assert result["location"] == "kawaii"

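# The tool-call arguments below deliberately concatenate two JSON objects; only the
# first complete object ({"location": "kawaii"}) is expected to be extracted.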
def test_extract_json_from_tool_call():
    """
    Test extract json from tool call.
    """

    node = init_parameter_extractor_node(
        config={
            "id": "llm",
            "data": {
                "title": "123",
                "type": "parameter-extractor",
                "model": {
                    "provider": "langgenius/openai/openai",
                    "name": "gpt-3.5-turbo-instruct",
                    "mode": "completion",
                    "completion_params": {},
                },
                "query": ["sys", "query"],
                "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}],
                "reasoning_mode": "prompt",
                "instruction": "{{#sys.query#}}",
                "memory": None,
            },
        },
    )

    result = node._extract_json_from_tool_call(
        AssistantPromptMessage.ToolCall(
            id="llm",
            type="parameter-extractor",
            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                name="foo", arguments="""{"location":"kawaii"}{"location": 1}"""
            ),
        )
    )

    assert result is not None
    assert result["location"] == "kawaii"

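# Memory is injected by monkeypatching core.workflow.nodes.llm.llm_utils.fetch_memory;
# the assertions also check that user/assistant prompt roles strictly alternate.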
def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
    """
    Test chat parameter extractor with memory.
    """
    node = init_parameter_extractor_node(
        config={
            "id": "llm",
            "data": {
                "title": "123",
                "type": "parameter-extractor",
                "model": {
                    "provider": "langgenius/openai/openai",
                    "name": "gpt-3.5-turbo",
                    "mode": "chat",
                    "completion_params": {},
                },
                "query": ["sys", "query"],
                "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}],
                "reasoning_mode": "prompt",
                "instruction": "",
                "memory": {"window": {"enabled": True, "size": 50}},
            },
        },
    )

    node._fetch_model_config = get_mocked_fetch_model_config(
        provider="langgenius/openai/openai",
        model="gpt-3.5-turbo",
        mode="chat",
        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
    )
    # Set up the memory mock before running the actual test
    monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory"))
    db.session.close = MagicMock()

    result = node._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs is not None
    assert result.outputs.get("location") == ""
    assert (
        result.outputs.get("__reason")
        == "Failed to extract result from function call or text response, using empty result."
    )
    assert result.process_data is not None
    prompts = result.process_data.get("prompts", [])

    latest_role = None
    for prompt in prompts:
        if prompt.get("role") == "user":
            if "<structure>" in prompt.get("text"):
                assert '<structure>\n{"type": "object"' in prompt.get("text")
        elif prompt.get("role") == "system":
            assert "customized memory" in prompt.get("text")

        if latest_role is not None:
            assert latest_role != prompt.get("role")

        if prompt.get("role") in {"user", "assistant"}:
            latest_role = prompt.get("role")
@@ -0,0 +1,92 @@
import time
import uuid

import pytest

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock

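# indirect=True hands the parameter list to the setup_code_executor_mock fixture
# (via request.param) instead of passing it to the test function itself.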
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_code(setup_code_executor_mock):
    code = """{{args2}}"""
    config = {
        "id": "1",
        "data": {
            "type": "template-transform",
            "title": "123",
            "variables": [
                {
                    "variable": "args1",
                    "value_selector": ["1", "args1"],
                },
                {"variable": "args2", "value_selector": ["1", "args2"]},
            ],
            "template": code,
        },
    }

    graph_config = {
        "edges": [
            {
                "id": "start-source-next-target",
                "source": "start",
                "target": "1",
            },
        ],
        "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
    }

    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_id="1",
        graph_config=graph_config,
        user_id="1",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=0,
    )

    # construct variable pool
    variable_pool = VariablePool(
        system_variables=SystemVariable(user_id="aaa", files=[]),
        user_inputs={},
        environment_variables=[],
        conversation_variables=[],
    )
    variable_pool.add(["1", "args1"], 1)
    variable_pool.add(["1", "args2"], 3)

    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

    # Create node factory
    node_factory = DifyNodeFactory(
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )

    graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

    node = TemplateTransformNode(
        id=str(uuid.uuid4()),
        config=config,
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )
    node.init_node_data(config.get("data", {}))

    # execute node
    result = node._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs is not None
    assert result.outputs["output"] == "3"
dify/api/tests/integration_tests/workflow/nodes/test_tool.py
@@ -0,0 +1,130 @@
import time
import uuid
from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import StreamCompletedEvent
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom

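# init_tool_node builds the same minimal start -> "1" graph used by the other node
# tests, but with an empty variable pool, and returns a ready-to-run ToolNode.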
def init_tool_node(config: dict):
    graph_config = {
        "edges": [
            {
                "id": "start-source-next-target",
                "source": "start",
                "target": "1",
            },
        ],
        "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
    }

    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_id="1",
        graph_config=graph_config,
        user_id="1",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=0,
    )

    # construct variable pool
    variable_pool = VariablePool(
        system_variables=SystemVariable(user_id="aaa", files=[]),
        user_inputs={},
        environment_variables=[],
        conversation_variables=[],
    )

    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

    # Create node factory
    node_factory = DifyNodeFactory(
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )

    graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

    node = ToolNode(
        id=str(uuid.uuid4()),
        config=config,
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )
    node.init_node_data(config.get("data", {}))
    return node

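# Both tests below run the builtin "time" provider's current_time tool and inspect the
# streamed StreamCompletedEvent; decrypt_tool_parameters is mocked so the tests do not
# depend on stored, encrypted tool configuration.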
def test_tool_variable_invoke():
    node = init_tool_node(
        config={
            "id": "1",
            "data": {
                "type": "tool",
                "title": "a",
                "desc": "a",
                "provider_id": "time",
                "provider_type": "builtin",
                "provider_name": "time",
                "tool_name": "current_time",
                "tool_label": "current_time",
                "tool_configurations": {},
                "tool_parameters": {},
            },
        }
    )

    ToolParameterConfigurationManager.decrypt_tool_parameters = MagicMock(return_value={"format": "%Y-%m-%d %H:%M:%S"})

    node.graph_runtime_state.variable_pool.add(["1", "args1"], "1+1")

    # execute node
    result = node._run()
    for item in result:
        if isinstance(item, StreamCompletedEvent):
            assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
            assert item.node_run_result.outputs is not None
            assert item.node_run_result.outputs.get("text") is not None

def test_tool_mixed_invoke():
    node = init_tool_node(
        config={
            "id": "1",
            "data": {
                "type": "tool",
                "title": "a",
                "desc": "a",
                "provider_id": "time",
                "provider_type": "builtin",
                "provider_name": "time",
                "tool_name": "current_time",
                "tool_label": "current_time",
                "tool_configurations": {
                    "format": "%Y-%m-%d %H:%M:%S",
                },
                "tool_parameters": {},
            },
        }
    )

    ToolParameterConfigurationManager.decrypt_tool_parameters = MagicMock(return_value={"format": "%Y-%m-%d %H:%M:%S"})

    # execute node
    result = node._run()
    for item in result:
        if isinstance(item, StreamCompletedEvent):
            assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
            assert item.node_run_result.outputs is not None
            assert item.node_run_result.outputs.get("text") is not None