This commit is contained in:
2025-12-01 17:21:38 +08:00
parent 32fee2b8ab
commit fab8c13cb3
7511 changed files with 996300 additions and 0 deletions

View File

@@ -0,0 +1,250 @@
#!/usr/bin/env python3
"""
Test suite for async client implementation in the Python SDK.
This test validates the async/await functionality using httpx.AsyncClient
and ensures API parity with sync clients.
"""
import unittest
from unittest.mock import Mock, patch, AsyncMock
from dify_client.async_client import (
AsyncDifyClient,
AsyncChatClient,
AsyncCompletionClient,
AsyncWorkflowClient,
AsyncWorkspaceClient,
AsyncKnowledgeBaseClient,
)
class TestAsyncAPIParity(unittest.TestCase):
"""Test that async clients have API parity with sync clients."""
def test_dify_client_api_parity(self):
"""Test AsyncDifyClient has same methods as DifyClient."""
from dify_client import DifyClient
sync_methods = {name for name in dir(DifyClient) if not name.startswith("_")}
async_methods = {name for name in dir(AsyncDifyClient) if not name.startswith("_")}
# aclose is async-specific, close is sync-specific
sync_methods.discard("close")
async_methods.discard("aclose")
# Verify parity
self.assertEqual(sync_methods, async_methods, "API parity mismatch for DifyClient")
def test_chat_client_api_parity(self):
"""Test AsyncChatClient has same methods as ChatClient."""
from dify_client import ChatClient
sync_methods = {name for name in dir(ChatClient) if not name.startswith("_")}
async_methods = {name for name in dir(AsyncChatClient) if not name.startswith("_")}
sync_methods.discard("close")
async_methods.discard("aclose")
self.assertEqual(sync_methods, async_methods, "API parity mismatch for ChatClient")
def test_completion_client_api_parity(self):
"""Test AsyncCompletionClient has same methods as CompletionClient."""
from dify_client import CompletionClient
sync_methods = {name for name in dir(CompletionClient) if not name.startswith("_")}
async_methods = {name for name in dir(AsyncCompletionClient) if not name.startswith("_")}
sync_methods.discard("close")
async_methods.discard("aclose")
self.assertEqual(sync_methods, async_methods, "API parity mismatch for CompletionClient")
def test_workflow_client_api_parity(self):
"""Test AsyncWorkflowClient has same methods as WorkflowClient."""
from dify_client import WorkflowClient
sync_methods = {name for name in dir(WorkflowClient) if not name.startswith("_")}
async_methods = {name for name in dir(AsyncWorkflowClient) if not name.startswith("_")}
sync_methods.discard("close")
async_methods.discard("aclose")
self.assertEqual(sync_methods, async_methods, "API parity mismatch for WorkflowClient")
def test_workspace_client_api_parity(self):
"""Test AsyncWorkspaceClient has same methods as WorkspaceClient."""
from dify_client import WorkspaceClient
sync_methods = {name for name in dir(WorkspaceClient) if not name.startswith("_")}
async_methods = {name for name in dir(AsyncWorkspaceClient) if not name.startswith("_")}
sync_methods.discard("close")
async_methods.discard("aclose")
self.assertEqual(sync_methods, async_methods, "API parity mismatch for WorkspaceClient")
def test_knowledge_base_client_api_parity(self):
"""Test AsyncKnowledgeBaseClient has same methods as KnowledgeBaseClient."""
from dify_client import KnowledgeBaseClient
sync_methods = {name for name in dir(KnowledgeBaseClient) if not name.startswith("_")}
async_methods = {name for name in dir(AsyncKnowledgeBaseClient) if not name.startswith("_")}
sync_methods.discard("close")
async_methods.discard("aclose")
self.assertEqual(sync_methods, async_methods, "API parity mismatch for KnowledgeBaseClient")
class TestAsyncClientMocked(unittest.IsolatedAsyncioTestCase):
"""Test async client with mocked httpx.AsyncClient."""
@patch("dify_client.async_client.httpx.AsyncClient")
async def test_async_client_initialization(self, mock_httpx_async_client):
"""Test async client initializes with httpx.AsyncClient."""
mock_client_instance = AsyncMock()
mock_httpx_async_client.return_value = mock_client_instance
client = AsyncDifyClient("test-key", "https://api.dify.ai/v1")
# Verify httpx.AsyncClient was called
mock_httpx_async_client.assert_called_once()
self.assertEqual(client.api_key, "test-key")
await client.aclose()
@patch("dify_client.async_client.httpx.AsyncClient")
async def test_async_context_manager(self, mock_httpx_async_client):
"""Test async context manager works."""
mock_client_instance = AsyncMock()
mock_httpx_async_client.return_value = mock_client_instance
async with AsyncDifyClient("test-key") as client:
self.assertEqual(client.api_key, "test-key")
# Verify aclose was called
mock_client_instance.aclose.assert_called_once()
@patch("dify_client.async_client.httpx.AsyncClient")
async def test_async_send_request(self, mock_httpx_async_client):
"""Test async _send_request method."""
mock_response = AsyncMock()
mock_response.json = AsyncMock(return_value={"result": "success"})
mock_response.status_code = 200
mock_client_instance = AsyncMock()
mock_client_instance.request = AsyncMock(return_value=mock_response)
mock_httpx_async_client.return_value = mock_client_instance
async with AsyncDifyClient("test-key") as client:
response = await client._send_request("GET", "/test")
# Verify request was called
mock_client_instance.request.assert_called_once()
call_args = mock_client_instance.request.call_args
# Verify parameters
self.assertEqual(call_args[0][0], "GET")
self.assertEqual(call_args[0][1], "/test")
@patch("dify_client.async_client.httpx.AsyncClient")
async def test_async_chat_client(self, mock_httpx_async_client):
"""Test AsyncChatClient functionality."""
mock_response = AsyncMock()
mock_response.text = '{"answer": "Hello!"}'
mock_response.json = AsyncMock(return_value={"answer": "Hello!"})
mock_client_instance = AsyncMock()
mock_client_instance.request = AsyncMock(return_value=mock_response)
mock_httpx_async_client.return_value = mock_client_instance
async with AsyncChatClient("test-key") as client:
response = await client.create_chat_message({}, "Hi", "user123")
self.assertIn("answer", response.text)
@patch("dify_client.async_client.httpx.AsyncClient")
async def test_async_completion_client(self, mock_httpx_async_client):
"""Test AsyncCompletionClient functionality."""
mock_response = AsyncMock()
mock_response.text = '{"answer": "Response"}'
mock_response.json = AsyncMock(return_value={"answer": "Response"})
mock_client_instance = AsyncMock()
mock_client_instance.request = AsyncMock(return_value=mock_response)
mock_httpx_async_client.return_value = mock_client_instance
async with AsyncCompletionClient("test-key") as client:
response = await client.create_completion_message({"query": "test"}, "blocking", "user123")
self.assertIn("answer", response.text)
@patch("dify_client.async_client.httpx.AsyncClient")
async def test_async_workflow_client(self, mock_httpx_async_client):
"""Test AsyncWorkflowClient functionality."""
mock_response = AsyncMock()
mock_response.json = AsyncMock(return_value={"result": "success"})
mock_client_instance = AsyncMock()
mock_client_instance.request = AsyncMock(return_value=mock_response)
mock_httpx_async_client.return_value = mock_client_instance
async with AsyncWorkflowClient("test-key") as client:
response = await client.run({"input": "test"}, "blocking", "user123")
data = await response.json()
self.assertEqual(data["result"], "success")
@patch("dify_client.async_client.httpx.AsyncClient")
async def test_async_workspace_client(self, mock_httpx_async_client):
"""Test AsyncWorkspaceClient functionality."""
mock_response = AsyncMock()
mock_response.json = AsyncMock(return_value={"data": []})
mock_client_instance = AsyncMock()
mock_client_instance.request = AsyncMock(return_value=mock_response)
mock_httpx_async_client.return_value = mock_client_instance
async with AsyncWorkspaceClient("test-key") as client:
response = await client.get_available_models("llm")
data = await response.json()
self.assertIn("data", data)
@patch("dify_client.async_client.httpx.AsyncClient")
async def test_async_knowledge_base_client(self, mock_httpx_async_client):
"""Test AsyncKnowledgeBaseClient functionality."""
mock_response = AsyncMock()
mock_response.json = AsyncMock(return_value={"data": [], "total": 0})
mock_client_instance = AsyncMock()
mock_client_instance.request = AsyncMock(return_value=mock_response)
mock_httpx_async_client.return_value = mock_client_instance
async with AsyncKnowledgeBaseClient("test-key") as client:
response = await client.list_datasets()
data = await response.json()
self.assertIn("data", data)
@patch("dify_client.async_client.httpx.AsyncClient")
async def test_all_async_client_classes(self, mock_httpx_async_client):
"""Test all async client classes work with httpx.AsyncClient."""
mock_client_instance = AsyncMock()
mock_httpx_async_client.return_value = mock_client_instance
clients = [
AsyncDifyClient("key"),
AsyncChatClient("key"),
AsyncCompletionClient("key"),
AsyncWorkflowClient("key"),
AsyncWorkspaceClient("key"),
AsyncKnowledgeBaseClient("key"),
]
# Verify httpx.AsyncClient was called for each
self.assertEqual(mock_httpx_async_client.call_count, 6)
# Clean up
for client in clients:
await client.aclose()
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,489 @@
import os
import time
import unittest
from unittest.mock import Mock, patch, mock_open
from dify_client.client import (
ChatClient,
CompletionClient,
DifyClient,
KnowledgeBaseClient,
)
API_KEY = os.environ.get("API_KEY")
APP_ID = os.environ.get("APP_ID")
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.dify.ai/v1")
FILE_PATH_BASE = os.path.dirname(__file__)
class TestKnowledgeBaseClient(unittest.TestCase):
def setUp(self):
self.api_key = "test-api-key"
self.base_url = "https://api.dify.ai/v1"
self.knowledge_base_client = KnowledgeBaseClient(self.api_key, base_url=self.base_url)
self.README_FILE_PATH = os.path.abspath(os.path.join(FILE_PATH_BASE, "../README.md"))
self.dataset_id = "test-dataset-id"
self.document_id = "test-document-id"
self.segment_id = "test-segment-id"
self.batch_id = "test-batch-id"
def _get_dataset_kb_client(self):
return KnowledgeBaseClient(self.api_key, base_url=self.base_url, dataset_id=self.dataset_id)
@patch("dify_client.client.httpx.Client")
def test_001_create_dataset(self, mock_httpx_client):
# Mock the HTTP response
mock_response = Mock()
mock_response.json.return_value = {"id": self.dataset_id, "name": "test_dataset"}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Re-create client with mocked httpx
self.knowledge_base_client = KnowledgeBaseClient(self.api_key, base_url=self.base_url)
response = self.knowledge_base_client.create_dataset(name="test_dataset")
data = response.json()
self.assertIn("id", data)
self.assertEqual("test_dataset", data["name"])
# the following tests require to be executed in order because they use
# the dataset/document/segment ids from the previous test
self._test_002_list_datasets()
self._test_003_create_document_by_text()
self._test_004_update_document_by_text()
self._test_006_update_document_by_file()
self._test_007_list_documents()
self._test_008_delete_document()
self._test_009_create_document_by_file()
self._test_010_add_segments()
self._test_011_query_segments()
self._test_012_update_document_segment()
self._test_013_delete_document_segment()
self._test_014_delete_dataset()
def _test_002_list_datasets(self):
# Mock the response - using the already mocked client from test_001_create_dataset
mock_response = Mock()
mock_response.json.return_value = {"data": [], "total": 0}
mock_response.status_code = 200
self.knowledge_base_client._client.request.return_value = mock_response
response = self.knowledge_base_client.list_datasets()
data = response.json()
self.assertIn("data", data)
self.assertIn("total", data)
def _test_003_create_document_by_text(self):
client = self._get_dataset_kb_client()
# Mock the response
mock_response = Mock()
mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id}
mock_response.status_code = 200
client._client.request.return_value = mock_response
response = client.create_document_by_text("test_document", "test_text")
data = response.json()
self.assertIn("document", data)
def _test_004_update_document_by_text(self):
client = self._get_dataset_kb_client()
# Mock the response
mock_response = Mock()
mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id}
mock_response.status_code = 200
client._client.request.return_value = mock_response
response = client.update_document_by_text(self.document_id, "test_document_updated", "test_text_updated")
data = response.json()
self.assertIn("document", data)
self.assertIn("batch", data)
def _test_006_update_document_by_file(self):
client = self._get_dataset_kb_client()
# Mock the response
mock_response = Mock()
mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id}
mock_response.status_code = 200
client._client.request.return_value = mock_response
response = client.update_document_by_file(self.document_id, self.README_FILE_PATH)
data = response.json()
self.assertIn("document", data)
self.assertIn("batch", data)
def _test_007_list_documents(self):
client = self._get_dataset_kb_client()
# Mock the response
mock_response = Mock()
mock_response.json.return_value = {"data": []}
mock_response.status_code = 200
client._client.request.return_value = mock_response
response = client.list_documents()
data = response.json()
self.assertIn("data", data)
def _test_008_delete_document(self):
client = self._get_dataset_kb_client()
# Mock the response
mock_response = Mock()
mock_response.json.return_value = {"result": "success"}
mock_response.status_code = 200
client._client.request.return_value = mock_response
response = client.delete_document(self.document_id)
data = response.json()
self.assertIn("result", data)
self.assertEqual("success", data["result"])
def _test_009_create_document_by_file(self):
client = self._get_dataset_kb_client()
# Mock the response
mock_response = Mock()
mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id}
mock_response.status_code = 200
client._client.request.return_value = mock_response
response = client.create_document_by_file(self.README_FILE_PATH)
data = response.json()
self.assertIn("document", data)
def _test_010_add_segments(self):
client = self._get_dataset_kb_client()
# Mock the response
mock_response = Mock()
mock_response.json.return_value = {"data": [{"id": self.segment_id, "content": "test text segment 1"}]}
mock_response.status_code = 200
client._client.request.return_value = mock_response
response = client.add_segments(self.document_id, [{"content": "test text segment 1"}])
data = response.json()
self.assertIn("data", data)
self.assertGreater(len(data["data"]), 0)
def _test_011_query_segments(self):
client = self._get_dataset_kb_client()
# Mock the response
mock_response = Mock()
mock_response.json.return_value = {"data": [{"id": self.segment_id, "content": "test text segment 1"}]}
mock_response.status_code = 200
client._client.request.return_value = mock_response
response = client.query_segments(self.document_id)
data = response.json()
self.assertIn("data", data)
self.assertGreater(len(data["data"]), 0)
def _test_012_update_document_segment(self):
client = self._get_dataset_kb_client()
# Mock the response
mock_response = Mock()
mock_response.json.return_value = {"data": {"id": self.segment_id, "content": "test text segment 1 updated"}}
mock_response.status_code = 200
client._client.request.return_value = mock_response
response = client.update_document_segment(
self.document_id,
self.segment_id,
{"content": "test text segment 1 updated"},
)
data = response.json()
self.assertIn("data", data)
self.assertEqual("test text segment 1 updated", data["data"]["content"])
def _test_013_delete_document_segment(self):
client = self._get_dataset_kb_client()
# Mock the response
mock_response = Mock()
mock_response.json.return_value = {"result": "success"}
mock_response.status_code = 200
client._client.request.return_value = mock_response
response = client.delete_document_segment(self.document_id, self.segment_id)
data = response.json()
self.assertIn("result", data)
self.assertEqual("success", data["result"])
def _test_014_delete_dataset(self):
client = self._get_dataset_kb_client()
# Mock the response
mock_response = Mock()
mock_response.status_code = 204
client._client.request.return_value = mock_response
response = client.delete_dataset()
self.assertEqual(204, response.status_code)
class TestChatClient(unittest.TestCase):
@patch("dify_client.client.httpx.Client")
def setUp(self, mock_httpx_client):
self.api_key = "test-api-key"
self.chat_client = ChatClient(self.api_key)
# Set up default mock response for the client
mock_response = Mock()
mock_response.text = '{"answer": "Hello! This is a test response."}'
mock_response.json.return_value = {"answer": "Hello! This is a test response."}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
@patch("dify_client.client.httpx.Client")
def test_create_chat_message(self, mock_httpx_client):
# Mock the HTTP response
mock_response = Mock()
mock_response.text = '{"answer": "Hello! This is a test response."}'
mock_response.json.return_value = {"answer": "Hello! This is a test response."}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Create client with mocked httpx
chat_client = ChatClient(self.api_key)
response = chat_client.create_chat_message({}, "Hello, World!", "test_user")
self.assertIn("answer", response.text)
@patch("dify_client.client.httpx.Client")
def test_create_chat_message_with_vision_model_by_remote_url(self, mock_httpx_client):
# Mock the HTTP response
mock_response = Mock()
mock_response.text = '{"answer": "I can see this is a test image description."}'
mock_response.json.return_value = {"answer": "I can see this is a test image description."}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Create client with mocked httpx
chat_client = ChatClient(self.api_key)
files = [{"type": "image", "transfer_method": "remote_url", "url": "https://example.com/test-image.jpg"}]
response = chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files)
self.assertIn("answer", response.text)
@patch("dify_client.client.httpx.Client")
def test_create_chat_message_with_vision_model_by_local_file(self, mock_httpx_client):
# Mock the HTTP response
mock_response = Mock()
mock_response.text = '{"answer": "I can see this is a test uploaded image."}'
mock_response.json.return_value = {"answer": "I can see this is a test uploaded image."}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Create client with mocked httpx
chat_client = ChatClient(self.api_key)
files = [
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": "test-file-id",
}
]
response = chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files)
self.assertIn("answer", response.text)
@patch("dify_client.client.httpx.Client")
def test_get_conversation_messages(self, mock_httpx_client):
# Mock the HTTP response
mock_response = Mock()
mock_response.text = '{"answer": "Here are the conversation messages."}'
mock_response.json.return_value = {"answer": "Here are the conversation messages."}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Create client with mocked httpx
chat_client = ChatClient(self.api_key)
response = chat_client.get_conversation_messages("test_user", "test-conversation-id")
self.assertIn("answer", response.text)
@patch("dify_client.client.httpx.Client")
def test_get_conversations(self, mock_httpx_client):
# Mock the HTTP response
mock_response = Mock()
mock_response.text = '{"data": [{"id": "conv1", "name": "Test Conversation"}]}'
mock_response.json.return_value = {"data": [{"id": "conv1", "name": "Test Conversation"}]}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Create client with mocked httpx
chat_client = ChatClient(self.api_key)
response = chat_client.get_conversations("test_user")
self.assertIn("data", response.text)
class TestCompletionClient(unittest.TestCase):
@patch("dify_client.client.httpx.Client")
def setUp(self, mock_httpx_client):
self.api_key = "test-api-key"
self.completion_client = CompletionClient(self.api_key)
# Set up default mock response for the client
mock_response = Mock()
mock_response.text = '{"answer": "This is a test completion response."}'
mock_response.json.return_value = {"answer": "This is a test completion response."}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
@patch("dify_client.client.httpx.Client")
def test_create_completion_message(self, mock_httpx_client):
# Mock the HTTP response
mock_response = Mock()
mock_response.text = '{"answer": "The weather today is sunny with a temperature of 75°F."}'
mock_response.json.return_value = {"answer": "The weather today is sunny with a temperature of 75°F."}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Create client with mocked httpx
completion_client = CompletionClient(self.api_key)
response = completion_client.create_completion_message(
{"query": "What's the weather like today?"}, "blocking", "test_user"
)
self.assertIn("answer", response.text)
@patch("dify_client.client.httpx.Client")
def test_create_completion_message_with_vision_model_by_remote_url(self, mock_httpx_client):
# Mock the HTTP response
mock_response = Mock()
mock_response.text = '{"answer": "This is a test image description from completion API."}'
mock_response.json.return_value = {"answer": "This is a test image description from completion API."}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Create client with mocked httpx
completion_client = CompletionClient(self.api_key)
files = [{"type": "image", "transfer_method": "remote_url", "url": "https://example.com/test-image.jpg"}]
response = completion_client.create_completion_message(
{"query": "Describe the picture."}, "blocking", "test_user", files
)
self.assertIn("answer", response.text)
@patch("dify_client.client.httpx.Client")
def test_create_completion_message_with_vision_model_by_local_file(self, mock_httpx_client):
# Mock the HTTP response
mock_response = Mock()
mock_response.text = '{"answer": "This is a test uploaded image description from completion API."}'
mock_response.json.return_value = {"answer": "This is a test uploaded image description from completion API."}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Create client with mocked httpx
completion_client = CompletionClient(self.api_key)
files = [
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": "test-file-id",
}
]
response = completion_client.create_completion_message(
{"query": "Describe the picture."}, "blocking", "test_user", files
)
self.assertIn("answer", response.text)
class TestDifyClient(unittest.TestCase):
@patch("dify_client.client.httpx.Client")
def setUp(self, mock_httpx_client):
self.api_key = "test-api-key"
self.dify_client = DifyClient(self.api_key)
# Set up default mock response for the client
mock_response = Mock()
mock_response.text = '{"result": "success"}'
mock_response.json.return_value = {"result": "success"}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
@patch("dify_client.client.httpx.Client")
def test_message_feedback(self, mock_httpx_client):
# Mock the HTTP response
mock_response = Mock()
mock_response.text = '{"success": true}'
mock_response.json.return_value = {"success": True}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Create client with mocked httpx
dify_client = DifyClient(self.api_key)
response = dify_client.message_feedback("test-message-id", "like", "test_user")
self.assertIn("success", response.text)
@patch("dify_client.client.httpx.Client")
def test_get_application_parameters(self, mock_httpx_client):
# Mock the HTTP response
mock_response = Mock()
mock_response.text = '{"user_input_form": [{"field": "text", "label": "Input"}]}'
mock_response.json.return_value = {"user_input_form": [{"field": "text", "label": "Input"}]}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Create client with mocked httpx
dify_client = DifyClient(self.api_key)
response = dify_client.get_application_parameters("test_user")
self.assertIn("user_input_form", response.text)
@patch("dify_client.client.httpx.Client")
@patch("builtins.open", new_callable=mock_open, read_data=b"fake image data")
def test_file_upload(self, mock_file_open, mock_httpx_client):
# Mock the HTTP response
mock_response = Mock()
mock_response.text = '{"name": "panda.jpeg", "id": "test-file-id"}'
mock_response.json.return_value = {"name": "panda.jpeg", "id": "test-file-id"}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Create client with mocked httpx
dify_client = DifyClient(self.api_key)
file_path = "/path/to/test/panda.jpeg"
file_name = "panda.jpeg"
mime_type = "image/jpeg"
with open(file_path, "rb") as file:
files = {"file": (file_name, file, mime_type)}
response = dify_client.file_upload("test_user", files)
self.assertIn("name", response.text)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,79 @@
"""Tests for custom exceptions."""
import unittest
from dify_client.exceptions import (
DifyClientError,
APIError,
AuthenticationError,
RateLimitError,
ValidationError,
NetworkError,
TimeoutError,
FileUploadError,
DatasetError,
WorkflowError,
)
class TestExceptions(unittest.TestCase):
"""Test custom exception classes."""
def test_base_exception(self):
"""Test base DifyClientError."""
error = DifyClientError("Test message", 500, {"error": "details"})
self.assertEqual(str(error), "Test message")
self.assertEqual(error.status_code, 500)
self.assertEqual(error.response, {"error": "details"})
def test_api_error(self):
"""Test APIError."""
error = APIError("API failed", 400)
self.assertEqual(error.status_code, 400)
self.assertEqual(error.message, "API failed")
def test_authentication_error(self):
"""Test AuthenticationError."""
error = AuthenticationError("Invalid API key")
self.assertEqual(str(error), "Invalid API key")
def test_rate_limit_error(self):
"""Test RateLimitError."""
error = RateLimitError("Rate limited", retry_after=60)
self.assertEqual(error.retry_after, 60)
error_default = RateLimitError()
self.assertEqual(error_default.retry_after, None)
def test_validation_error(self):
"""Test ValidationError."""
error = ValidationError("Invalid parameter")
self.assertEqual(str(error), "Invalid parameter")
def test_network_error(self):
"""Test NetworkError."""
error = NetworkError("Connection failed")
self.assertEqual(str(error), "Connection failed")
def test_timeout_error(self):
"""Test TimeoutError."""
error = TimeoutError("Request timed out")
self.assertEqual(str(error), "Request timed out")
def test_file_upload_error(self):
"""Test FileUploadError."""
error = FileUploadError("Upload failed")
self.assertEqual(str(error), "Upload failed")
def test_dataset_error(self):
"""Test DatasetError."""
error = DatasetError("Dataset operation failed")
self.assertEqual(str(error), "Dataset operation failed")
def test_workflow_error(self):
"""Test WorkflowError."""
error = WorkflowError("Workflow failed")
self.assertEqual(str(error), "Workflow failed")
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,333 @@
#!/usr/bin/env python3
"""
Test suite for httpx migration in the Python SDK.
This test validates that the migration from requests to httpx maintains
backward compatibility and proper resource management.
"""
import unittest
from unittest.mock import Mock, patch
from dify_client import (
DifyClient,
ChatClient,
CompletionClient,
WorkflowClient,
WorkspaceClient,
KnowledgeBaseClient,
)
class TestHttpxMigrationMocked(unittest.TestCase):
"""Test cases for httpx migration with mocked requests."""
def setUp(self):
"""Set up test fixtures."""
self.api_key = "test-api-key"
self.base_url = "https://api.dify.ai/v1"
@patch("dify_client.client.httpx.Client")
def test_client_initialization(self, mock_httpx_client):
"""Test that client initializes with httpx.Client."""
mock_client_instance = Mock()
mock_httpx_client.return_value = mock_client_instance
client = DifyClient(self.api_key, self.base_url)
# Verify httpx.Client was called with correct parameters
mock_httpx_client.assert_called_once()
call_kwargs = mock_httpx_client.call_args[1]
self.assertEqual(call_kwargs["base_url"], self.base_url)
# Verify client properties
self.assertEqual(client.api_key, self.api_key)
self.assertEqual(client.base_url, self.base_url)
client.close()
@patch("dify_client.client.httpx.Client")
def test_context_manager_support(self, mock_httpx_client):
"""Test that client works as context manager."""
mock_client_instance = Mock()
mock_httpx_client.return_value = mock_client_instance
with DifyClient(self.api_key, self.base_url) as client:
self.assertEqual(client.api_key, self.api_key)
# Verify close was called
mock_client_instance.close.assert_called_once()
@patch("dify_client.client.httpx.Client")
def test_manual_close(self, mock_httpx_client):
"""Test manual close() method."""
mock_client_instance = Mock()
mock_httpx_client.return_value = mock_client_instance
client = DifyClient(self.api_key, self.base_url)
client.close()
# Verify close was called
mock_client_instance.close.assert_called_once()
@patch("dify_client.client.httpx.Client")
def test_send_request_httpx_compatibility(self, mock_httpx_client):
"""Test _send_request uses httpx.Client.request properly."""
mock_response = Mock()
mock_response.json.return_value = {"result": "success"}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
client = DifyClient(self.api_key, self.base_url)
response = client._send_request("GET", "/test-endpoint")
# Verify httpx.Client.request was called correctly
mock_client_instance.request.assert_called_once()
call_args = mock_client_instance.request.call_args
# Verify method and endpoint
self.assertEqual(call_args[0][0], "GET")
self.assertEqual(call_args[0][1], "/test-endpoint")
# Verify headers contain authorization
headers = call_args[1]["headers"]
self.assertEqual(headers["Authorization"], f"Bearer {self.api_key}")
self.assertEqual(headers["Content-Type"], "application/json")
client.close()
@patch("dify_client.client.httpx.Client")
def test_response_compatibility(self, mock_httpx_client):
"""Test httpx.Response is compatible with requests.Response API."""
mock_response = Mock()
mock_response.json.return_value = {"key": "value"}
mock_response.text = '{"key": "value"}'
mock_response.content = b'{"key": "value"}'
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
client = DifyClient(self.api_key, self.base_url)
response = client._send_request("GET", "/test")
# Verify all common response methods work
self.assertEqual(response.json(), {"key": "value"})
self.assertEqual(response.text, '{"key": "value"}')
self.assertEqual(response.content, b'{"key": "value"}')
self.assertEqual(response.status_code, 200)
self.assertEqual(response.headers["Content-Type"], "application/json")
client.close()
@patch("dify_client.client.httpx.Client")
def test_all_client_classes_use_httpx(self, mock_httpx_client):
"""Test that all client classes properly use httpx."""
mock_client_instance = Mock()
mock_httpx_client.return_value = mock_client_instance
clients = [
DifyClient(self.api_key, self.base_url),
ChatClient(self.api_key, self.base_url),
CompletionClient(self.api_key, self.base_url),
WorkflowClient(self.api_key, self.base_url),
WorkspaceClient(self.api_key, self.base_url),
KnowledgeBaseClient(self.api_key, self.base_url),
]
# Verify httpx.Client was called for each client
self.assertEqual(mock_httpx_client.call_count, 6)
# Clean up
for client in clients:
client.close()
@patch("dify_client.client.httpx.Client")
def test_json_parameter_handling(self, mock_httpx_client):
"""Test that json parameter is passed correctly."""
mock_response = Mock()
mock_response.json.return_value = {"result": "success"}
mock_response.status_code = 200 # Add status_code attribute
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
client = DifyClient(self.api_key, self.base_url)
test_data = {"key": "value", "number": 123}
client._send_request("POST", "/test", json=test_data)
# Verify json parameter was passed
call_args = mock_client_instance.request.call_args
self.assertEqual(call_args[1]["json"], test_data)
client.close()
@patch("dify_client.client.httpx.Client")
def test_params_parameter_handling(self, mock_httpx_client):
"""Test that params parameter is passed correctly."""
mock_response = Mock()
mock_response.json.return_value = {"result": "success"}
mock_response.status_code = 200 # Add status_code attribute
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
client = DifyClient(self.api_key, self.base_url)
test_params = {"page": 1, "limit": 20}
client._send_request("GET", "/test", params=test_params)
# Verify params parameter was passed
call_args = mock_client_instance.request.call_args
self.assertEqual(call_args[1]["params"], test_params)
client.close()
@patch("dify_client.client.httpx.Client")
def test_inheritance_chain(self, mock_httpx_client):
"""Test that inheritance chain is maintained."""
mock_client_instance = Mock()
mock_httpx_client.return_value = mock_client_instance
# ChatClient inherits from DifyClient
chat_client = ChatClient(self.api_key, self.base_url)
self.assertIsInstance(chat_client, DifyClient)
# CompletionClient inherits from DifyClient
completion_client = CompletionClient(self.api_key, self.base_url)
self.assertIsInstance(completion_client, DifyClient)
# WorkflowClient inherits from DifyClient
workflow_client = WorkflowClient(self.api_key, self.base_url)
self.assertIsInstance(workflow_client, DifyClient)
# Clean up
chat_client.close()
completion_client.close()
workflow_client.close()
@patch("dify_client.client.httpx.Client")
def test_nested_context_managers(self, mock_httpx_client):
"""Test nested context managers work correctly."""
mock_client_instance = Mock()
mock_httpx_client.return_value = mock_client_instance
with DifyClient(self.api_key, self.base_url) as client1:
with ChatClient(self.api_key, self.base_url) as client2:
self.assertEqual(client1.api_key, self.api_key)
self.assertEqual(client2.api_key, self.api_key)
# Both close methods should have been called
self.assertEqual(mock_client_instance.close.call_count, 2)
class TestChatClientHttpx(unittest.TestCase):
"""Test ChatClient specific httpx integration."""
@patch("dify_client.client.httpx.Client")
def test_create_chat_message_httpx(self, mock_httpx_client):
"""Test create_chat_message works with httpx."""
mock_response = Mock()
mock_response.text = '{"answer": "Hello!"}'
mock_response.json.return_value = {"answer": "Hello!"}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
with ChatClient("test-key") as client:
response = client.create_chat_message({}, "Hi", "user123")
self.assertIn("answer", response.text)
self.assertEqual(response.json()["answer"], "Hello!")
class TestCompletionClientHttpx(unittest.TestCase):
"""Test CompletionClient specific httpx integration."""
@patch("dify_client.client.httpx.Client")
def test_create_completion_message_httpx(self, mock_httpx_client):
"""Test create_completion_message works with httpx."""
mock_response = Mock()
mock_response.text = '{"answer": "Response"}'
mock_response.json.return_value = {"answer": "Response"}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
with CompletionClient("test-key") as client:
response = client.create_completion_message({"query": "test"}, "blocking", "user123")
self.assertIn("answer", response.text)
class TestKnowledgeBaseClientHttpx(unittest.TestCase):
"""Test KnowledgeBaseClient specific httpx integration."""
@patch("dify_client.client.httpx.Client")
def test_list_datasets_httpx(self, mock_httpx_client):
"""Test list_datasets works with httpx."""
mock_response = Mock()
mock_response.json.return_value = {"data": [], "total": 0}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
with KnowledgeBaseClient("test-key") as client:
response = client.list_datasets()
data = response.json()
self.assertIn("data", data)
self.assertIn("total", data)
class TestWorkflowClientHttpx(unittest.TestCase):
"""Test WorkflowClient specific httpx integration."""
@patch("dify_client.client.httpx.Client")
def test_run_workflow_httpx(self, mock_httpx_client):
"""Test run workflow works with httpx."""
mock_response = Mock()
mock_response.json.return_value = {"result": "success"}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
with WorkflowClient("test-key") as client:
response = client.run({"input": "test"}, "blocking", "user123")
self.assertEqual(response.json()["result"], "success")
class TestWorkspaceClientHttpx(unittest.TestCase):
"""Test WorkspaceClient specific httpx integration."""
@patch("dify_client.client.httpx.Client")
def test_get_available_models_httpx(self, mock_httpx_client):
"""Test get_available_models works with httpx."""
mock_response = Mock()
mock_response.json.return_value = {"data": []}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
with WorkspaceClient("test-key") as client:
response = client.get_available_models("llm")
self.assertIn("data", response.json())
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,539 @@
"""Integration tests with proper mocking."""
import unittest
from unittest.mock import Mock, patch, MagicMock
import json
import httpx
from dify_client import (
DifyClient,
ChatClient,
CompletionClient,
WorkflowClient,
KnowledgeBaseClient,
WorkspaceClient,
)
from dify_client.exceptions import (
APIError,
AuthenticationError,
RateLimitError,
ValidationError,
)
class TestDifyClientIntegration(unittest.TestCase):
"""Integration tests for DifyClient with mocked HTTP responses."""
def setUp(self):
self.api_key = "test_api_key"
self.base_url = "https://api.dify.ai/v1"
self.client = DifyClient(api_key=self.api_key, base_url=self.base_url, enable_logging=False)
@patch("httpx.Client.request")
def test_get_app_info_integration(self, mock_request):
"""Test get_app_info integration."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"id": "app_123",
"name": "Test App",
"description": "A test application",
"mode": "chat",
}
mock_request.return_value = mock_response
response = self.client.get_app_info()
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(data["id"], "app_123")
self.assertEqual(data["name"], "Test App")
mock_request.assert_called_once_with(
"GET",
"/info",
json=None,
params=None,
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
)
@patch("httpx.Client.request")
def test_get_application_parameters_integration(self, mock_request):
"""Test get_application_parameters integration."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"opening_statement": "Hello! How can I help you?",
"suggested_questions": ["What is AI?", "How does this work?"],
"speech_to_text": {"enabled": True},
"text_to_speech": {"enabled": False},
}
mock_request.return_value = mock_response
response = self.client.get_application_parameters("user_123")
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(data["opening_statement"], "Hello! How can I help you?")
self.assertEqual(len(data["suggested_questions"]), 2)
mock_request.assert_called_once_with(
"GET",
"/parameters",
json=None,
params={"user": "user_123"},
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
)
@patch("httpx.Client.request")
def test_file_upload_integration(self, mock_request):
"""Test file_upload integration."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"id": "file_123",
"name": "test.txt",
"size": 1024,
"mime_type": "text/plain",
}
mock_request.return_value = mock_response
files = {"file": ("test.txt", "test content", "text/plain")}
response = self.client.file_upload("user_123", files)
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(data["id"], "file_123")
self.assertEqual(data["name"], "test.txt")
@patch("httpx.Client.request")
def test_message_feedback_integration(self, mock_request):
"""Test message_feedback integration."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"success": True}
mock_request.return_value = mock_response
response = self.client.message_feedback("msg_123", "like", "user_123")
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertTrue(data["success"])
mock_request.assert_called_once_with(
"POST",
"/messages/msg_123/feedbacks",
json={"rating": "like", "user": "user_123"},
params=None,
headers={
"Authorization": "Bearer test_api_key",
"Content-Type": "application/json",
},
)
class TestChatClientIntegration(unittest.TestCase):
"""Integration tests for ChatClient."""
def setUp(self):
self.client = ChatClient("test_api_key", enable_logging=False)
@patch("httpx.Client.request")
def test_create_chat_message_blocking(self, mock_request):
"""Test create_chat_message with blocking response."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"id": "msg_123",
"answer": "Hello! How can I help you today?",
"conversation_id": "conv_123",
"created_at": 1234567890,
}
mock_request.return_value = mock_response
response = self.client.create_chat_message(
inputs={"query": "Hello"},
query="Hello, AI!",
user="user_123",
response_mode="blocking",
)
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(data["answer"], "Hello! How can I help you today?")
self.assertEqual(data["conversation_id"], "conv_123")
@patch("httpx.Client.request")
def test_create_chat_message_streaming(self, mock_request):
"""Test create_chat_message with streaming response."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.iter_lines.return_value = [
b'data: {"answer": "Hello"}',
b'data: {"answer": " world"}',
b'data: {"answer": "!"}',
]
mock_request.return_value = mock_response
response = self.client.create_chat_message(inputs={}, query="Hello", user="user_123", response_mode="streaming")
self.assertEqual(response.status_code, 200)
lines = list(response.iter_lines())
self.assertEqual(len(lines), 3)
@patch("httpx.Client.request")
def test_get_conversations_integration(self, mock_request):
"""Test get_conversations integration."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"data": [
{"id": "conv_1", "name": "Conversation 1"},
{"id": "conv_2", "name": "Conversation 2"},
],
"has_more": False,
"limit": 20,
}
mock_request.return_value = mock_response
response = self.client.get_conversations("user_123", limit=20)
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(len(data["data"]), 2)
self.assertEqual(data["data"][0]["name"], "Conversation 1")
@patch("httpx.Client.request")
def test_get_conversation_messages_integration(self, mock_request):
"""Test get_conversation_messages integration."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"data": [
{"id": "msg_1", "role": "user", "content": "Hello"},
{"id": "msg_2", "role": "assistant", "content": "Hi there!"},
]
}
mock_request.return_value = mock_response
response = self.client.get_conversation_messages("user_123", conversation_id="conv_123")
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(len(data["data"]), 2)
self.assertEqual(data["data"][0]["role"], "user")
class TestCompletionClientIntegration(unittest.TestCase):
"""Integration tests for CompletionClient."""
def setUp(self):
self.client = CompletionClient("test_api_key", enable_logging=False)
@patch("httpx.Client.request")
def test_create_completion_message_blocking(self, mock_request):
"""Test create_completion_message with blocking response."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"id": "comp_123",
"answer": "This is a completion response.",
"created_at": 1234567890,
}
mock_request.return_value = mock_response
response = self.client.create_completion_message(
inputs={"prompt": "Complete this sentence"},
response_mode="blocking",
user="user_123",
)
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(data["answer"], "This is a completion response.")
@patch("httpx.Client.request")
def test_create_completion_message_with_files(self, mock_request):
"""Test create_completion_message with files."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"id": "comp_124",
"answer": "I can see the image shows...",
"files": [{"id": "file_1", "type": "image"}],
}
mock_request.return_value = mock_response
files = {
"file": {
"type": "image",
"transfer_method": "remote_url",
"url": "https://example.com/image.jpg",
}
}
response = self.client.create_completion_message(
inputs={"prompt": "Describe this image"},
response_mode="blocking",
user="user_123",
files=files,
)
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertIn("image", data["answer"])
self.assertEqual(len(data["files"]), 1)
class TestWorkflowClientIntegration(unittest.TestCase):
"""Integration tests for WorkflowClient."""
def setUp(self):
self.client = WorkflowClient("test_api_key", enable_logging=False)
@patch("httpx.Client.request")
def test_run_workflow_blocking(self, mock_request):
"""Test run workflow with blocking response."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"id": "run_123",
"workflow_id": "workflow_123",
"status": "succeeded",
"inputs": {"query": "Test input"},
"outputs": {"result": "Test output"},
"elapsed_time": 2.5,
}
mock_request.return_value = mock_response
response = self.client.run(inputs={"query": "Test input"}, response_mode="blocking", user="user_123")
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(data["status"], "succeeded")
self.assertEqual(data["outputs"]["result"], "Test output")
@patch("httpx.Client.request")
def test_get_workflow_logs(self, mock_request):
"""Test get_workflow_logs integration."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"logs": [
{"id": "log_1", "status": "succeeded", "created_at": 1234567890},
{"id": "log_2", "status": "failed", "created_at": 1234567891},
],
"total": 2,
"page": 1,
"limit": 20,
}
mock_request.return_value = mock_response
response = self.client.get_workflow_logs(page=1, limit=20)
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(len(data["logs"]), 2)
self.assertEqual(data["logs"][0]["status"], "succeeded")
class TestKnowledgeBaseClientIntegration(unittest.TestCase):
"""Integration tests for KnowledgeBaseClient."""
def setUp(self):
self.client = KnowledgeBaseClient("test_api_key")
@patch("httpx.Client.request")
def test_create_dataset(self, mock_request):
"""Test create_dataset integration."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"id": "dataset_123",
"name": "Test Dataset",
"description": "A test dataset",
"created_at": 1234567890,
}
mock_request.return_value = mock_response
response = self.client.create_dataset(name="Test Dataset")
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(data["name"], "Test Dataset")
self.assertEqual(data["id"], "dataset_123")
@patch("httpx.Client.request")
def test_list_datasets(self, mock_request):
"""Test list_datasets integration."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"data": [
{"id": "dataset_1", "name": "Dataset 1"},
{"id": "dataset_2", "name": "Dataset 2"},
],
"has_more": False,
"limit": 20,
}
mock_request.return_value = mock_response
response = self.client.list_datasets(page=1, page_size=20)
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(len(data["data"]), 2)
@patch("httpx.Client.request")
def test_create_document_by_text(self, mock_request):
"""Test create_document_by_text integration."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"document": {
"id": "doc_123",
"name": "Test Document",
"word_count": 100,
"status": "indexing",
}
}
mock_request.return_value = mock_response
# Mock dataset_id
self.client.dataset_id = "dataset_123"
response = self.client.create_document_by_text(name="Test Document", text="This is test document content.")
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(data["document"]["name"], "Test Document")
self.assertEqual(data["document"]["word_count"], 100)
class TestWorkspaceClientIntegration(unittest.TestCase):
"""Integration tests for WorkspaceClient."""
def setUp(self):
self.client = WorkspaceClient("test_api_key", enable_logging=False)
@patch("httpx.Client.request")
def test_get_available_models(self, mock_request):
"""Test get_available_models integration."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"models": [
{"id": "gpt-4", "name": "GPT-4", "provider": "openai"},
{"id": "claude-3", "name": "Claude 3", "provider": "anthropic"},
]
}
mock_request.return_value = mock_response
response = self.client.get_available_models("llm")
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(len(data["models"]), 2)
self.assertEqual(data["models"][0]["id"], "gpt-4")
class TestErrorScenariosIntegration(unittest.TestCase):
"""Integration tests for error scenarios."""
def setUp(self):
self.client = DifyClient("test_api_key", enable_logging=False)
@patch("httpx.Client.request")
def test_authentication_error_integration(self, mock_request):
"""Test authentication error in integration."""
mock_response = Mock()
mock_response.status_code = 401
mock_response.json.return_value = {"message": "Invalid API key"}
mock_request.return_value = mock_response
with self.assertRaises(AuthenticationError) as context:
self.client.get_app_info()
self.assertEqual(str(context.exception), "Invalid API key")
self.assertEqual(context.exception.status_code, 401)
@patch("httpx.Client.request")
def test_rate_limit_error_integration(self, mock_request):
"""Test rate limit error in integration."""
mock_response = Mock()
mock_response.status_code = 429
mock_response.json.return_value = {"message": "Rate limit exceeded"}
mock_response.headers = {"Retry-After": "60"}
mock_request.return_value = mock_response
with self.assertRaises(RateLimitError) as context:
self.client.get_app_info()
self.assertEqual(str(context.exception), "Rate limit exceeded")
self.assertEqual(context.exception.retry_after, "60")
@patch("httpx.Client.request")
def test_server_error_with_retry_integration(self, mock_request):
"""Test server error with retry in integration."""
# API errors don't retry by design - only network/timeout errors retry
mock_response_500 = Mock()
mock_response_500.status_code = 500
mock_response_500.json.return_value = {"message": "Internal server error"}
mock_request.return_value = mock_response_500
with patch("time.sleep"): # Skip actual sleep
with self.assertRaises(APIError) as context:
self.client.get_app_info()
self.assertEqual(str(context.exception), "Internal server error")
self.assertEqual(mock_request.call_count, 1)
@patch("httpx.Client.request")
def test_validation_error_integration(self, mock_request):
"""Test validation error in integration."""
mock_response = Mock()
mock_response.status_code = 422
mock_response.json.return_value = {
"message": "Validation failed",
"details": {"field": "query", "error": "required"},
}
mock_request.return_value = mock_response
with self.assertRaises(ValidationError) as context:
self.client.get_app_info()
self.assertEqual(str(context.exception), "Validation failed")
self.assertEqual(context.exception.status_code, 422)
class TestContextManagerIntegration(unittest.TestCase):
"""Integration tests for context manager usage."""
@patch("httpx.Client.close")
@patch("httpx.Client.request")
def test_context_manager_usage(self, mock_request, mock_close):
"""Test context manager properly closes connections."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"id": "app_123", "name": "Test App"}
mock_request.return_value = mock_response
with DifyClient("test_api_key") as client:
response = client.get_app_info()
self.assertEqual(response.status_code, 200)
# Verify close was called
mock_close.assert_called_once()
@patch("httpx.Client.close")
def test_manual_close(self, mock_close):
"""Test manual close method."""
client = DifyClient("test_api_key")
client.close()
mock_close.assert_called_once()
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,640 @@
"""Unit tests for response models."""
import unittest
import json
from datetime import datetime
from dify_client.models import (
BaseResponse,
ErrorResponse,
FileInfo,
MessageResponse,
ConversationResponse,
DatasetResponse,
DocumentResponse,
DocumentSegmentResponse,
WorkflowRunResponse,
ApplicationParametersResponse,
AnnotationResponse,
PaginatedResponse,
ConversationVariableResponse,
FileUploadResponse,
AudioResponse,
SuggestedQuestionsResponse,
AppInfoResponse,
WorkspaceModelsResponse,
HitTestingResponse,
DatasetTagsResponse,
WorkflowLogsResponse,
ModelProviderResponse,
FileInfoResponse,
WorkflowDraftResponse,
ApiTokenResponse,
JobStatusResponse,
DatasetQueryResponse,
DatasetTemplateResponse,
)
class TestResponseModels(unittest.TestCase):
"""Test cases for response model classes."""
def test_base_response(self):
"""Test BaseResponse model."""
response = BaseResponse(success=True, message="Operation successful")
self.assertTrue(response.success)
self.assertEqual(response.message, "Operation successful")
def test_base_response_defaults(self):
"""Test BaseResponse with default values."""
response = BaseResponse(success=True)
self.assertTrue(response.success)
self.assertIsNone(response.message)
def test_error_response(self):
"""Test ErrorResponse model."""
response = ErrorResponse(
success=False,
message="Error occurred",
error_code="VALIDATION_ERROR",
details={"field": "invalid_value"},
)
self.assertFalse(response.success)
self.assertEqual(response.message, "Error occurred")
self.assertEqual(response.error_code, "VALIDATION_ERROR")
self.assertEqual(response.details["field"], "invalid_value")
def test_file_info(self):
"""Test FileInfo model."""
now = datetime.now()
file_info = FileInfo(
id="file_123",
name="test.txt",
size=1024,
mime_type="text/plain",
url="https://example.com/file.txt",
created_at=now,
)
self.assertEqual(file_info.id, "file_123")
self.assertEqual(file_info.name, "test.txt")
self.assertEqual(file_info.size, 1024)
self.assertEqual(file_info.mime_type, "text/plain")
self.assertEqual(file_info.url, "https://example.com/file.txt")
self.assertEqual(file_info.created_at, now)
def test_message_response(self):
"""Test MessageResponse model."""
response = MessageResponse(
success=True,
id="msg_123",
answer="Hello, world!",
conversation_id="conv_123",
created_at=1234567890,
metadata={"model": "gpt-4"},
files=[{"id": "file_1", "type": "image"}],
)
self.assertTrue(response.success)
self.assertEqual(response.id, "msg_123")
self.assertEqual(response.answer, "Hello, world!")
self.assertEqual(response.conversation_id, "conv_123")
self.assertEqual(response.created_at, 1234567890)
self.assertEqual(response.metadata["model"], "gpt-4")
self.assertEqual(response.files[0]["id"], "file_1")
def test_conversation_response(self):
"""Test ConversationResponse model."""
response = ConversationResponse(
success=True,
id="conv_123",
name="Test Conversation",
inputs={"query": "Hello"},
status="active",
created_at=1234567890,
updated_at=1234567891,
)
self.assertTrue(response.success)
self.assertEqual(response.id, "conv_123")
self.assertEqual(response.name, "Test Conversation")
self.assertEqual(response.inputs["query"], "Hello")
self.assertEqual(response.status, "active")
self.assertEqual(response.created_at, 1234567890)
self.assertEqual(response.updated_at, 1234567891)
def test_dataset_response(self):
"""Test DatasetResponse model."""
response = DatasetResponse(
success=True,
id="dataset_123",
name="Test Dataset",
description="A test dataset",
permission="read",
indexing_technique="high_quality",
embedding_model="text-embedding-ada-002",
embedding_model_provider="openai",
retrieval_model={"search_type": "semantic"},
document_count=10,
word_count=5000,
app_count=2,
created_at=1234567890,
updated_at=1234567891,
)
self.assertTrue(response.success)
self.assertEqual(response.id, "dataset_123")
self.assertEqual(response.name, "Test Dataset")
self.assertEqual(response.description, "A test dataset")
self.assertEqual(response.permission, "read")
self.assertEqual(response.indexing_technique, "high_quality")
self.assertEqual(response.embedding_model, "text-embedding-ada-002")
self.assertEqual(response.embedding_model_provider, "openai")
self.assertEqual(response.retrieval_model["search_type"], "semantic")
self.assertEqual(response.document_count, 10)
self.assertEqual(response.word_count, 5000)
self.assertEqual(response.app_count, 2)
def test_document_response(self):
"""Test DocumentResponse model."""
response = DocumentResponse(
success=True,
id="doc_123",
name="test_document.txt",
data_source_type="upload_file",
position=1,
enabled=True,
word_count=1000,
hit_count=5,
doc_form="text_model",
created_at=1234567890.0,
indexing_status="completed",
completed_at=1234567891.0,
)
self.assertTrue(response.success)
self.assertEqual(response.id, "doc_123")
self.assertEqual(response.name, "test_document.txt")
self.assertEqual(response.data_source_type, "upload_file")
self.assertEqual(response.position, 1)
self.assertTrue(response.enabled)
self.assertEqual(response.word_count, 1000)
self.assertEqual(response.hit_count, 5)
self.assertEqual(response.doc_form, "text_model")
self.assertEqual(response.created_at, 1234567890.0)
self.assertEqual(response.indexing_status, "completed")
self.assertEqual(response.completed_at, 1234567891.0)
def test_document_segment_response(self):
"""Test DocumentSegmentResponse model."""
response = DocumentSegmentResponse(
success=True,
id="seg_123",
position=1,
document_id="doc_123",
content="This is a test segment.",
answer="Test answer",
word_count=5,
tokens=10,
keywords=["test", "segment"],
hit_count=2,
enabled=True,
status="completed",
created_at=1234567890.0,
completed_at=1234567891.0,
)
self.assertTrue(response.success)
self.assertEqual(response.id, "seg_123")
self.assertEqual(response.position, 1)
self.assertEqual(response.document_id, "doc_123")
self.assertEqual(response.content, "This is a test segment.")
self.assertEqual(response.answer, "Test answer")
self.assertEqual(response.word_count, 5)
self.assertEqual(response.tokens, 10)
self.assertEqual(response.keywords, ["test", "segment"])
self.assertEqual(response.hit_count, 2)
self.assertTrue(response.enabled)
self.assertEqual(response.status, "completed")
self.assertEqual(response.created_at, 1234567890.0)
self.assertEqual(response.completed_at, 1234567891.0)
def test_workflow_run_response(self):
"""Test WorkflowRunResponse model."""
response = WorkflowRunResponse(
success=True,
id="run_123",
workflow_id="workflow_123",
status="succeeded",
inputs={"query": "test"},
outputs={"answer": "result"},
elapsed_time=5.5,
total_tokens=100,
total_steps=3,
created_at=1234567890.0,
finished_at=1234567895.5,
)
self.assertTrue(response.success)
self.assertEqual(response.id, "run_123")
self.assertEqual(response.workflow_id, "workflow_123")
self.assertEqual(response.status, "succeeded")
self.assertEqual(response.inputs["query"], "test")
self.assertEqual(response.outputs["answer"], "result")
self.assertEqual(response.elapsed_time, 5.5)
self.assertEqual(response.total_tokens, 100)
self.assertEqual(response.total_steps, 3)
self.assertEqual(response.created_at, 1234567890.0)
self.assertEqual(response.finished_at, 1234567895.5)
def test_application_parameters_response(self):
"""Test ApplicationParametersResponse model."""
response = ApplicationParametersResponse(
success=True,
opening_statement="Hello! How can I help you?",
suggested_questions=["What is AI?", "How does this work?"],
speech_to_text={"enabled": True},
text_to_speech={"enabled": False, "voice": "alloy"},
retriever_resource={"enabled": True},
sensitive_word_avoidance={"enabled": False},
file_upload={"enabled": True, "file_size_limit": 10485760},
system_parameters={"max_tokens": 1000},
user_input_form=[{"type": "text", "label": "Query"}],
)
self.assertTrue(response.success)
self.assertEqual(response.opening_statement, "Hello! How can I help you?")
self.assertEqual(response.suggested_questions, ["What is AI?", "How does this work?"])
self.assertTrue(response.speech_to_text["enabled"])
self.assertFalse(response.text_to_speech["enabled"])
self.assertEqual(response.text_to_speech["voice"], "alloy")
self.assertTrue(response.retriever_resource["enabled"])
self.assertFalse(response.sensitive_word_avoidance["enabled"])
self.assertTrue(response.file_upload["enabled"])
self.assertEqual(response.file_upload["file_size_limit"], 10485760)
self.assertEqual(response.system_parameters["max_tokens"], 1000)
self.assertEqual(response.user_input_form[0]["type"], "text")
def test_annotation_response(self):
"""Test AnnotationResponse model."""
response = AnnotationResponse(
success=True,
id="annotation_123",
question="What is the capital of France?",
answer="Paris",
content="Additional context",
created_at=1234567890.0,
updated_at=1234567891.0,
created_by="user_123",
updated_by="user_123",
hit_count=5,
)
self.assertTrue(response.success)
self.assertEqual(response.id, "annotation_123")
self.assertEqual(response.question, "What is the capital of France?")
self.assertEqual(response.answer, "Paris")
self.assertEqual(response.content, "Additional context")
self.assertEqual(response.created_at, 1234567890.0)
self.assertEqual(response.updated_at, 1234567891.0)
self.assertEqual(response.created_by, "user_123")
self.assertEqual(response.updated_by, "user_123")
self.assertEqual(response.hit_count, 5)
def test_paginated_response(self):
"""Test PaginatedResponse model."""
response = PaginatedResponse(
success=True,
data=[{"id": 1}, {"id": 2}, {"id": 3}],
has_more=True,
limit=10,
total=100,
page=1,
)
self.assertTrue(response.success)
self.assertEqual(len(response.data), 3)
self.assertEqual(response.data[0]["id"], 1)
self.assertTrue(response.has_more)
self.assertEqual(response.limit, 10)
self.assertEqual(response.total, 100)
self.assertEqual(response.page, 1)
def test_conversation_variable_response(self):
"""Test ConversationVariableResponse model."""
response = ConversationVariableResponse(
success=True,
conversation_id="conv_123",
variables=[
{"id": "var_1", "name": "user_name", "value": "John"},
{"id": "var_2", "name": "preferences", "value": {"theme": "dark"}},
],
)
self.assertTrue(response.success)
self.assertEqual(response.conversation_id, "conv_123")
self.assertEqual(len(response.variables), 2)
self.assertEqual(response.variables[0]["name"], "user_name")
self.assertEqual(response.variables[0]["value"], "John")
self.assertEqual(response.variables[1]["name"], "preferences")
self.assertEqual(response.variables[1]["value"]["theme"], "dark")
def test_file_upload_response(self):
"""Test FileUploadResponse model."""
response = FileUploadResponse(
success=True,
id="file_123",
name="test.txt",
size=1024,
mime_type="text/plain",
url="https://example.com/files/test.txt",
created_at=1234567890.0,
)
self.assertTrue(response.success)
self.assertEqual(response.id, "file_123")
self.assertEqual(response.name, "test.txt")
self.assertEqual(response.size, 1024)
self.assertEqual(response.mime_type, "text/plain")
self.assertEqual(response.url, "https://example.com/files/test.txt")
self.assertEqual(response.created_at, 1234567890.0)
def test_audio_response(self):
"""Test AudioResponse model."""
response = AudioResponse(
success=True,
audio="base64_encoded_audio_data",
audio_url="https://example.com/audio.mp3",
duration=10.5,
sample_rate=44100,
)
self.assertTrue(response.success)
self.assertEqual(response.audio, "base64_encoded_audio_data")
self.assertEqual(response.audio_url, "https://example.com/audio.mp3")
self.assertEqual(response.duration, 10.5)
self.assertEqual(response.sample_rate, 44100)
def test_suggested_questions_response(self):
"""Test SuggestedQuestionsResponse model."""
response = SuggestedQuestionsResponse(
success=True,
message_id="msg_123",
questions=[
"What is machine learning?",
"How does AI work?",
"Can you explain neural networks?",
],
)
self.assertTrue(response.success)
self.assertEqual(response.message_id, "msg_123")
self.assertEqual(len(response.questions), 3)
self.assertEqual(response.questions[0], "What is machine learning?")
def test_app_info_response(self):
"""Test AppInfoResponse model."""
response = AppInfoResponse(
success=True,
id="app_123",
name="Test App",
description="A test application",
icon="🤖",
icon_background="#FF6B6B",
mode="chat",
tags=["AI", "Chat", "Test"],
enable_site=True,
enable_api=True,
api_token="app_token_123",
)
self.assertTrue(response.success)
self.assertEqual(response.id, "app_123")
self.assertEqual(response.name, "Test App")
self.assertEqual(response.description, "A test application")
self.assertEqual(response.icon, "🤖")
self.assertEqual(response.icon_background, "#FF6B6B")
self.assertEqual(response.mode, "chat")
self.assertEqual(response.tags, ["AI", "Chat", "Test"])
self.assertTrue(response.enable_site)
self.assertTrue(response.enable_api)
self.assertEqual(response.api_token, "app_token_123")
def test_workspace_models_response(self):
"""Test WorkspaceModelsResponse model."""
response = WorkspaceModelsResponse(
success=True,
models=[
{"id": "gpt-4", "name": "GPT-4", "provider": "openai"},
{"id": "claude-3", "name": "Claude 3", "provider": "anthropic"},
],
)
self.assertTrue(response.success)
self.assertEqual(len(response.models), 2)
self.assertEqual(response.models[0]["id"], "gpt-4")
self.assertEqual(response.models[0]["name"], "GPT-4")
self.assertEqual(response.models[0]["provider"], "openai")
def test_hit_testing_response(self):
"""Test HitTestingResponse model."""
response = HitTestingResponse(
success=True,
query="What is machine learning?",
records=[
{"content": "Machine learning is a subset of AI...", "score": 0.95},
{"content": "ML algorithms learn from data...", "score": 0.87},
],
)
self.assertTrue(response.success)
self.assertEqual(response.query, "What is machine learning?")
self.assertEqual(len(response.records), 2)
self.assertEqual(response.records[0]["score"], 0.95)
def test_dataset_tags_response(self):
"""Test DatasetTagsResponse model."""
response = DatasetTagsResponse(
success=True,
tags=[
{"id": "tag_1", "name": "Technology", "color": "#FF0000"},
{"id": "tag_2", "name": "Science", "color": "#00FF00"},
],
)
self.assertTrue(response.success)
self.assertEqual(len(response.tags), 2)
self.assertEqual(response.tags[0]["name"], "Technology")
self.assertEqual(response.tags[0]["color"], "#FF0000")
def test_workflow_logs_response(self):
"""Test WorkflowLogsResponse model."""
response = WorkflowLogsResponse(
success=True,
logs=[
{"id": "log_1", "status": "succeeded", "created_at": 1234567890},
{"id": "log_2", "status": "failed", "created_at": 1234567891},
],
total=50,
page=1,
limit=10,
has_more=True,
)
self.assertTrue(response.success)
self.assertEqual(len(response.logs), 2)
self.assertEqual(response.logs[0]["status"], "succeeded")
self.assertEqual(response.total, 50)
self.assertEqual(response.page, 1)
self.assertEqual(response.limit, 10)
self.assertTrue(response.has_more)
def test_model_serialization(self):
"""Test that models can be serialized to JSON."""
response = MessageResponse(
success=True,
id="msg_123",
answer="Hello, world!",
conversation_id="conv_123",
)
# Convert to dict and then to JSON
response_dict = {
"success": response.success,
"id": response.id,
"answer": response.answer,
"conversation_id": response.conversation_id,
}
json_str = json.dumps(response_dict)
parsed = json.loads(json_str)
self.assertTrue(parsed["success"])
self.assertEqual(parsed["id"], "msg_123")
self.assertEqual(parsed["answer"], "Hello, world!")
self.assertEqual(parsed["conversation_id"], "conv_123")
# Tests for new response models
def test_model_provider_response(self):
"""Test ModelProviderResponse model."""
response = ModelProviderResponse(
success=True,
provider_name="openai",
provider_type="llm",
models=[
{"id": "gpt-4", "name": "GPT-4", "max_tokens": 8192},
{"id": "gpt-3.5-turbo", "name": "GPT-3.5 Turbo", "max_tokens": 4096},
],
is_enabled=True,
credentials={"api_key": "sk-..."},
)
self.assertTrue(response.success)
self.assertEqual(response.provider_name, "openai")
self.assertEqual(response.provider_type, "llm")
self.assertEqual(len(response.models), 2)
self.assertEqual(response.models[0]["id"], "gpt-4")
self.assertTrue(response.is_enabled)
self.assertEqual(response.credentials["api_key"], "sk-...")
def test_file_info_response(self):
"""Test FileInfoResponse model."""
response = FileInfoResponse(
success=True,
id="file_123",
name="document.pdf",
size=2048576,
mime_type="application/pdf",
url="https://example.com/files/document.pdf",
created_at=1234567890,
metadata={"pages": 10, "author": "John Doe"},
)
self.assertTrue(response.success)
self.assertEqual(response.id, "file_123")
self.assertEqual(response.name, "document.pdf")
self.assertEqual(response.size, 2048576)
self.assertEqual(response.mime_type, "application/pdf")
self.assertEqual(response.url, "https://example.com/files/document.pdf")
self.assertEqual(response.created_at, 1234567890)
self.assertEqual(response.metadata["pages"], 10)
def test_workflow_draft_response(self):
"""Test WorkflowDraftResponse model."""
response = WorkflowDraftResponse(
success=True,
id="draft_123",
app_id="app_456",
draft_data={"nodes": [], "edges": [], "config": {"name": "Test Workflow"}},
version=1,
created_at=1234567890,
updated_at=1234567891,
)
self.assertTrue(response.success)
self.assertEqual(response.id, "draft_123")
self.assertEqual(response.app_id, "app_456")
self.assertEqual(response.draft_data["config"]["name"], "Test Workflow")
self.assertEqual(response.version, 1)
self.assertEqual(response.created_at, 1234567890)
self.assertEqual(response.updated_at, 1234567891)
def test_api_token_response(self):
"""Test ApiTokenResponse model."""
response = ApiTokenResponse(
success=True,
id="token_123",
name="Production Token",
token="app-xxxxxxxxxxxx",
description="Token for production environment",
created_at=1234567890,
last_used_at=1234567891,
is_active=True,
)
self.assertTrue(response.success)
self.assertEqual(response.id, "token_123")
self.assertEqual(response.name, "Production Token")
self.assertEqual(response.token, "app-xxxxxxxxxxxx")
self.assertEqual(response.description, "Token for production environment")
self.assertEqual(response.created_at, 1234567890)
self.assertEqual(response.last_used_at, 1234567891)
self.assertTrue(response.is_active)
def test_job_status_response(self):
"""Test JobStatusResponse model."""
response = JobStatusResponse(
success=True,
job_id="job_123",
job_status="running",
error_msg=None,
progress=0.75,
created_at=1234567890,
updated_at=1234567891,
)
self.assertTrue(response.success)
self.assertEqual(response.job_id, "job_123")
self.assertEqual(response.job_status, "running")
self.assertIsNone(response.error_msg)
self.assertEqual(response.progress, 0.75)
self.assertEqual(response.created_at, 1234567890)
self.assertEqual(response.updated_at, 1234567891)
def test_dataset_query_response(self):
"""Test DatasetQueryResponse model."""
response = DatasetQueryResponse(
success=True,
query="What is machine learning?",
records=[
{"content": "Machine learning is...", "score": 0.95},
{"content": "ML algorithms...", "score": 0.87},
],
total=2,
search_time=0.123,
retrieval_model={"method": "semantic_search", "top_k": 3},
)
self.assertTrue(response.success)
self.assertEqual(response.query, "What is machine learning?")
self.assertEqual(len(response.records), 2)
self.assertEqual(response.total, 2)
self.assertEqual(response.search_time, 0.123)
self.assertEqual(response.retrieval_model["method"], "semantic_search")
def test_dataset_template_response(self):
"""Test DatasetTemplateResponse model."""
response = DatasetTemplateResponse(
success=True,
template_name="customer_support",
display_name="Customer Support",
description="Template for customer support knowledge base",
category="support",
icon="🎧",
config_schema={"fields": [{"name": "category", "type": "string"}]},
)
self.assertTrue(response.success)
self.assertEqual(response.template_name, "customer_support")
self.assertEqual(response.display_name, "Customer Support")
self.assertEqual(response.description, "Template for customer support knowledge base")
self.assertEqual(response.category, "support")
self.assertEqual(response.icon, "🎧")
self.assertEqual(response.config_schema["fields"][0]["name"], "category")
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,313 @@
"""Unit tests for retry mechanism and error handling."""
import unittest
from unittest.mock import Mock, patch, MagicMock
import httpx
from dify_client.client import DifyClient
from dify_client.exceptions import (
APIError,
AuthenticationError,
RateLimitError,
ValidationError,
NetworkError,
TimeoutError,
FileUploadError,
)
class TestRetryMechanism(unittest.TestCase):
"""Test cases for retry mechanism."""
def setUp(self):
self.api_key = "test_api_key"
self.base_url = "https://api.dify.ai/v1"
self.client = DifyClient(
api_key=self.api_key,
base_url=self.base_url,
max_retries=3,
retry_delay=0.1, # Short delay for tests
enable_logging=False,
)
@patch("httpx.Client.request")
def test_successful_request_no_retry(self, mock_request):
"""Test that successful requests don't trigger retries."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.content = b'{"success": true}'
mock_request.return_value = mock_response
response = self.client._send_request("GET", "/test")
self.assertEqual(response, mock_response)
self.assertEqual(mock_request.call_count, 1)
@patch("httpx.Client.request")
@patch("time.sleep")
def test_retry_on_network_error(self, mock_sleep, mock_request):
"""Test retry on network errors."""
# First two calls raise network error, third succeeds
mock_request.side_effect = [
httpx.NetworkError("Connection failed"),
httpx.NetworkError("Connection failed"),
Mock(status_code=200, content=b'{"success": true}'),
]
mock_response = Mock()
mock_response.status_code = 200
mock_response.content = b'{"success": true}'
response = self.client._send_request("GET", "/test")
self.assertEqual(response.status_code, 200)
self.assertEqual(mock_request.call_count, 3)
self.assertEqual(mock_sleep.call_count, 2)
@patch("httpx.Client.request")
@patch("time.sleep")
def test_retry_on_timeout_error(self, mock_sleep, mock_request):
"""Test retry on timeout errors."""
mock_request.side_effect = [
httpx.TimeoutException("Request timed out"),
httpx.TimeoutException("Request timed out"),
Mock(status_code=200, content=b'{"success": true}'),
]
response = self.client._send_request("GET", "/test")
self.assertEqual(response.status_code, 200)
self.assertEqual(mock_request.call_count, 3)
self.assertEqual(mock_sleep.call_count, 2)
@patch("httpx.Client.request")
@patch("time.sleep")
def test_max_retries_exceeded(self, mock_sleep, mock_request):
"""Test behavior when max retries are exceeded."""
mock_request.side_effect = httpx.NetworkError("Persistent network error")
with self.assertRaises(NetworkError):
self.client._send_request("GET", "/test")
self.assertEqual(mock_request.call_count, 4) # 1 initial + 3 retries
self.assertEqual(mock_sleep.call_count, 3)
@patch("httpx.Client.request")
def test_no_retry_on_client_error(self, mock_request):
"""Test that client errors (4xx) don't trigger retries."""
mock_response = Mock()
mock_response.status_code = 401
mock_response.json.return_value = {"message": "Unauthorized"}
mock_request.return_value = mock_response
with self.assertRaises(AuthenticationError):
self.client._send_request("GET", "/test")
self.assertEqual(mock_request.call_count, 1)
@patch("httpx.Client.request")
def test_retry_on_server_error(self, mock_request):
"""Test that server errors (5xx) don't retry - they raise APIError immediately."""
mock_response_500 = Mock()
mock_response_500.status_code = 500
mock_response_500.json.return_value = {"message": "Internal server error"}
mock_request.return_value = mock_response_500
with self.assertRaises(APIError) as context:
self.client._send_request("GET", "/test")
self.assertEqual(str(context.exception), "Internal server error")
self.assertEqual(context.exception.status_code, 500)
# Should not retry server errors
self.assertEqual(mock_request.call_count, 1)
@patch("httpx.Client.request")
def test_exponential_backoff(self, mock_request):
"""Test exponential backoff timing."""
mock_request.side_effect = [
httpx.NetworkError("Connection failed"),
httpx.NetworkError("Connection failed"),
httpx.NetworkError("Connection failed"),
httpx.NetworkError("Connection failed"), # All attempts fail
]
with patch("time.sleep") as mock_sleep:
with self.assertRaises(NetworkError):
self.client._send_request("GET", "/test")
# Check exponential backoff: 0.1, 0.2, 0.4
expected_calls = [0.1, 0.2, 0.4]
actual_calls = [call[0][0] for call in mock_sleep.call_args_list]
self.assertEqual(actual_calls, expected_calls)
class TestErrorHandling(unittest.TestCase):
"""Test cases for error handling."""
def setUp(self):
self.client = DifyClient(api_key="test_api_key", enable_logging=False)
@patch("httpx.Client.request")
def test_authentication_error(self, mock_request):
"""Test AuthenticationError handling."""
mock_response = Mock()
mock_response.status_code = 401
mock_response.json.return_value = {"message": "Invalid API key"}
mock_request.return_value = mock_response
with self.assertRaises(AuthenticationError) as context:
self.client._send_request("GET", "/test")
self.assertEqual(str(context.exception), "Invalid API key")
self.assertEqual(context.exception.status_code, 401)
@patch("httpx.Client.request")
def test_rate_limit_error(self, mock_request):
"""Test RateLimitError handling."""
mock_response = Mock()
mock_response.status_code = 429
mock_response.json.return_value = {"message": "Rate limit exceeded"}
mock_response.headers = {"Retry-After": "60"}
mock_request.return_value = mock_response
with self.assertRaises(RateLimitError) as context:
self.client._send_request("GET", "/test")
self.assertEqual(str(context.exception), "Rate limit exceeded")
self.assertEqual(context.exception.retry_after, "60")
@patch("httpx.Client.request")
def test_validation_error(self, mock_request):
"""Test ValidationError handling."""
mock_response = Mock()
mock_response.status_code = 422
mock_response.json.return_value = {"message": "Invalid parameters"}
mock_request.return_value = mock_response
with self.assertRaises(ValidationError) as context:
self.client._send_request("GET", "/test")
self.assertEqual(str(context.exception), "Invalid parameters")
self.assertEqual(context.exception.status_code, 422)
@patch("httpx.Client.request")
def test_api_error(self, mock_request):
"""Test general APIError handling."""
mock_response = Mock()
mock_response.status_code = 500
mock_response.json.return_value = {"message": "Internal server error"}
mock_request.return_value = mock_response
with self.assertRaises(APIError) as context:
self.client._send_request("GET", "/test")
self.assertEqual(str(context.exception), "Internal server error")
self.assertEqual(context.exception.status_code, 500)
@patch("httpx.Client.request")
def test_error_response_without_json(self, mock_request):
"""Test error handling when response doesn't contain valid JSON."""
mock_response = Mock()
mock_response.status_code = 500
mock_response.content = b"Internal Server Error"
mock_response.json.side_effect = ValueError("No JSON object could be decoded")
mock_request.return_value = mock_response
with self.assertRaises(APIError) as context:
self.client._send_request("GET", "/test")
self.assertEqual(str(context.exception), "HTTP 500")
@patch("httpx.Client.request")
def test_file_upload_error(self, mock_request):
"""Test FileUploadError handling."""
mock_response = Mock()
mock_response.status_code = 400
mock_response.json.return_value = {"message": "File upload failed"}
mock_request.return_value = mock_response
with self.assertRaises(FileUploadError) as context:
self.client._send_request_with_files("POST", "/upload", {}, {})
self.assertEqual(str(context.exception), "File upload failed")
self.assertEqual(context.exception.status_code, 400)
class TestParameterValidation(unittest.TestCase):
"""Test cases for parameter validation."""
def setUp(self):
self.client = DifyClient(api_key="test_api_key", enable_logging=False)
def test_empty_string_validation(self):
"""Test validation of empty strings."""
with self.assertRaises(ValidationError):
self.client._validate_params(empty_string="")
def test_whitespace_only_string_validation(self):
"""Test validation of whitespace-only strings."""
with self.assertRaises(ValidationError):
self.client._validate_params(whitespace_string=" ")
def test_long_string_validation(self):
"""Test validation of overly long strings."""
long_string = "a" * 10001 # Exceeds 10000 character limit
with self.assertRaises(ValidationError):
self.client._validate_params(long_string=long_string)
def test_large_list_validation(self):
"""Test validation of overly large lists."""
large_list = list(range(1001)) # Exceeds 1000 item limit
with self.assertRaises(ValidationError):
self.client._validate_params(large_list=large_list)
def test_large_dict_validation(self):
"""Test validation of overly large dictionaries."""
large_dict = {f"key_{i}": i for i in range(101)} # Exceeds 100 item limit
with self.assertRaises(ValidationError):
self.client._validate_params(large_dict=large_dict)
def test_valid_parameters_pass(self):
"""Test that valid parameters pass validation."""
# Should not raise any exception
self.client._validate_params(
valid_string="Hello, World!",
valid_list=[1, 2, 3],
valid_dict={"key": "value"},
none_value=None,
)
def test_message_feedback_validation(self):
"""Test validation in message_feedback method."""
with self.assertRaises(ValidationError):
self.client.message_feedback("msg_id", "invalid_rating", "user")
def test_completion_message_validation(self):
"""Test validation in create_completion_message method."""
from dify_client.client import CompletionClient
client = CompletionClient("test_api_key")
with self.assertRaises(ValidationError):
client.create_completion_message(
inputs="not_a_dict", # Should be a dict
response_mode="invalid_mode", # Should be 'blocking' or 'streaming'
user="test_user",
)
def test_chat_message_validation(self):
"""Test validation in create_chat_message method."""
from dify_client.client import ChatClient
client = ChatClient("test_api_key")
with self.assertRaises(ValidationError):
client.create_chat_message(
inputs="not_a_dict", # Should be a dict
query="", # Should not be empty
user="test_user",
response_mode="invalid_mode", # Should be 'blocking' or 'streaming'
)
if __name__ == "__main__":
unittest.main()