dify
This commit is contained in:
File diff suppressed because it is too large
Load Diff
381
dify/api/tests/unit_tests/libs/test_cron_compatibility.py
Normal file
381
dify/api/tests/unit_tests/libs/test_cron_compatibility.py
Normal file
@@ -0,0 +1,381 @@
|
||||
"""
|
||||
Enhanced cron syntax compatibility tests for croniter backend.
|
||||
|
||||
This test suite mirrors the frontend cron-parser tests to ensure
|
||||
complete compatibility between frontend and backend cron processing.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
import pytz
|
||||
from croniter import CroniterBadCronError
|
||||
|
||||
from libs.schedule_utils import calculate_next_run_at
|
||||
|
||||
|
||||
class TestCronCompatibility(unittest.TestCase):
|
||||
"""Test enhanced cron syntax compatibility with frontend."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment with fixed time."""
|
||||
self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)
|
||||
|
||||
def test_enhanced_dayofweek_syntax(self):
|
||||
"""Test enhanced day-of-week syntax compatibility."""
|
||||
test_cases = [
|
||||
("0 9 * * 7", 0), # Sunday as 7
|
||||
("0 9 * * 0", 0), # Sunday as 0
|
||||
("0 9 * * MON", 1), # Monday abbreviation
|
||||
("0 9 * * TUE", 2), # Tuesday abbreviation
|
||||
("0 9 * * WED", 3), # Wednesday abbreviation
|
||||
("0 9 * * THU", 4), # Thursday abbreviation
|
||||
("0 9 * * FRI", 5), # Friday abbreviation
|
||||
("0 9 * * SAT", 6), # Saturday abbreviation
|
||||
("0 9 * * SUN", 0), # Sunday abbreviation
|
||||
]
|
||||
|
||||
for expr, expected_weekday in test_cases:
|
||||
with self.subTest(expr=expr):
|
||||
next_time = calculate_next_run_at(expr, "UTC", self.base_time)
|
||||
assert next_time is not None
|
||||
assert (next_time.weekday() + 1 if next_time.weekday() < 6 else 0) == expected_weekday
|
||||
assert next_time.hour == 9
|
||||
assert next_time.minute == 0
|
||||
|
||||
def test_enhanced_month_syntax(self):
|
||||
"""Test enhanced month syntax compatibility."""
|
||||
test_cases = [
|
||||
("0 9 1 JAN *", 1), # January abbreviation
|
||||
("0 9 1 FEB *", 2), # February abbreviation
|
||||
("0 9 1 MAR *", 3), # March abbreviation
|
||||
("0 9 1 APR *", 4), # April abbreviation
|
||||
("0 9 1 MAY *", 5), # May abbreviation
|
||||
("0 9 1 JUN *", 6), # June abbreviation
|
||||
("0 9 1 JUL *", 7), # July abbreviation
|
||||
("0 9 1 AUG *", 8), # August abbreviation
|
||||
("0 9 1 SEP *", 9), # September abbreviation
|
||||
("0 9 1 OCT *", 10), # October abbreviation
|
||||
("0 9 1 NOV *", 11), # November abbreviation
|
||||
("0 9 1 DEC *", 12), # December abbreviation
|
||||
]
|
||||
|
||||
for expr, expected_month in test_cases:
|
||||
with self.subTest(expr=expr):
|
||||
next_time = calculate_next_run_at(expr, "UTC", self.base_time)
|
||||
assert next_time is not None
|
||||
assert next_time.month == expected_month
|
||||
assert next_time.day == 1
|
||||
assert next_time.hour == 9
|
||||
|
||||
def test_predefined_expressions(self):
|
||||
"""Test predefined cron expressions compatibility."""
|
||||
test_cases = [
|
||||
("@yearly", lambda dt: dt.month == 1 and dt.day == 1 and dt.hour == 0),
|
||||
("@annually", lambda dt: dt.month == 1 and dt.day == 1 and dt.hour == 0),
|
||||
("@monthly", lambda dt: dt.day == 1 and dt.hour == 0),
|
||||
("@weekly", lambda dt: dt.weekday() == 6 and dt.hour == 0), # Sunday = 6 in weekday()
|
||||
("@daily", lambda dt: dt.hour == 0 and dt.minute == 0),
|
||||
("@midnight", lambda dt: dt.hour == 0 and dt.minute == 0),
|
||||
("@hourly", lambda dt: dt.minute == 0),
|
||||
]
|
||||
|
||||
for expr, validator in test_cases:
|
||||
with self.subTest(expr=expr):
|
||||
next_time = calculate_next_run_at(expr, "UTC", self.base_time)
|
||||
assert next_time is not None
|
||||
assert validator(next_time), f"Validator failed for {expr}: {next_time}"
|
||||
|
||||
def test_special_characters(self):
|
||||
"""Test special characters in cron expressions."""
|
||||
test_cases = [
|
||||
"0 9 ? * 1", # ? wildcard
|
||||
"0 12 * * 7", # Sunday as 7
|
||||
"0 15 L * *", # Last day of month
|
||||
]
|
||||
|
||||
for expr in test_cases:
|
||||
with self.subTest(expr=expr):
|
||||
try:
|
||||
next_time = calculate_next_run_at(expr, "UTC", self.base_time)
|
||||
assert next_time is not None
|
||||
assert next_time > self.base_time
|
||||
except Exception as e:
|
||||
self.fail(f"Expression '{expr}' should be valid but raised: {e}")
|
||||
|
||||
def test_range_and_list_syntax(self):
|
||||
"""Test range and list syntax with abbreviations."""
|
||||
test_cases = [
|
||||
"0 9 * * MON-FRI", # Weekday range with abbreviations
|
||||
"0 9 * JAN-MAR *", # Month range with abbreviations
|
||||
"0 9 * * SUN,WED,FRI", # Weekday list with abbreviations
|
||||
"0 9 1 JAN,JUN,DEC *", # Month list with abbreviations
|
||||
]
|
||||
|
||||
for expr in test_cases:
|
||||
with self.subTest(expr=expr):
|
||||
try:
|
||||
next_time = calculate_next_run_at(expr, "UTC", self.base_time)
|
||||
assert next_time is not None
|
||||
assert next_time > self.base_time
|
||||
except Exception as e:
|
||||
self.fail(f"Expression '{expr}' should be valid but raised: {e}")
|
||||
|
||||
def test_invalid_enhanced_syntax(self):
|
||||
"""Test that invalid enhanced syntax is properly rejected."""
|
||||
invalid_expressions = [
|
||||
"0 12 * JANUARY *", # Full month name (not supported)
|
||||
"0 12 * * MONDAY", # Full day name (not supported)
|
||||
"0 12 32 JAN *", # Invalid day with valid month
|
||||
"15 10 1 * 8", # Invalid day of week
|
||||
"15 10 1 INVALID *", # Invalid month abbreviation
|
||||
"15 10 1 * INVALID", # Invalid day abbreviation
|
||||
"@invalid", # Invalid predefined expression
|
||||
]
|
||||
|
||||
for expr in invalid_expressions:
|
||||
with self.subTest(expr=expr):
|
||||
with pytest.raises((CroniterBadCronError, ValueError)):
|
||||
calculate_next_run_at(expr, "UTC", self.base_time)
|
||||
|
||||
def test_edge_cases_with_enhanced_syntax(self):
|
||||
"""Test edge cases with enhanced syntax."""
|
||||
test_cases = [
|
||||
("0 0 29 FEB *", lambda dt: dt.month == 2 and dt.day == 29), # Feb 29 with month abbreviation
|
||||
]
|
||||
|
||||
for expr, validator in test_cases:
|
||||
with self.subTest(expr=expr):
|
||||
try:
|
||||
next_time = calculate_next_run_at(expr, "UTC", self.base_time)
|
||||
if next_time: # Some combinations might not occur soon
|
||||
assert validator(next_time), f"Validator failed for {expr}: {next_time}"
|
||||
except (CroniterBadCronError, ValueError):
|
||||
# Some edge cases might be valid but not have upcoming occurrences
|
||||
pass
|
||||
|
||||
# Test complex expressions that have specific constraints
|
||||
complex_expr = "59 23 31 DEC SAT" # December 31st at 23:59 on Saturday
|
||||
try:
|
||||
next_time = calculate_next_run_at(complex_expr, "UTC", self.base_time)
|
||||
if next_time:
|
||||
# The next occurrence might not be exactly Dec 31 if it's not a Saturday
|
||||
# Just verify it's a valid result
|
||||
assert next_time is not None
|
||||
assert next_time.hour == 23
|
||||
assert next_time.minute == 59
|
||||
except Exception:
|
||||
# Complex date constraints might not have near-future occurrences
|
||||
pass
|
||||
|
||||
|
||||
class TestTimezoneCompatibility(unittest.TestCase):
|
||||
"""Test timezone compatibility between frontend and backend."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)
|
||||
|
||||
def test_timezone_consistency(self):
|
||||
"""Test that calculations are consistent across different timezones."""
|
||||
timezones = [
|
||||
"UTC",
|
||||
"America/New_York",
|
||||
"Europe/London",
|
||||
"Asia/Tokyo",
|
||||
"Asia/Kolkata",
|
||||
"Australia/Sydney",
|
||||
]
|
||||
|
||||
expression = "0 12 * * *" # Daily at noon
|
||||
|
||||
for timezone in timezones:
|
||||
with self.subTest(timezone=timezone):
|
||||
next_time = calculate_next_run_at(expression, timezone, self.base_time)
|
||||
assert next_time is not None
|
||||
|
||||
# Convert back to the target timezone to verify it's noon
|
||||
tz = pytz.timezone(timezone)
|
||||
local_time = next_time.astimezone(tz)
|
||||
assert local_time.hour == 12
|
||||
assert local_time.minute == 0
|
||||
|
||||
def test_dst_handling(self):
|
||||
"""Test DST boundary handling."""
|
||||
# Test around DST spring forward (March 2024)
|
||||
dst_base = datetime(2024, 3, 8, 10, 0, 0, tzinfo=UTC)
|
||||
expression = "0 2 * * *" # 2 AM daily (problematic during DST)
|
||||
timezone = "America/New_York"
|
||||
|
||||
try:
|
||||
next_time = calculate_next_run_at(expression, timezone, dst_base)
|
||||
assert next_time is not None
|
||||
|
||||
# During DST spring forward, 2 AM becomes 3 AM - both are acceptable
|
||||
tz = pytz.timezone(timezone)
|
||||
local_time = next_time.astimezone(tz)
|
||||
assert local_time.hour in [2, 3] # Either 2 AM or 3 AM is acceptable
|
||||
except Exception as e:
|
||||
self.fail(f"DST handling failed: {e}")
|
||||
|
||||
def test_half_hour_timezones(self):
|
||||
"""Test timezones with half-hour offsets."""
|
||||
timezones_with_offsets = [
|
||||
("Asia/Kolkata", 17, 30), # UTC+5:30 -> 12:00 UTC = 17:30 IST
|
||||
("Australia/Adelaide", 22, 30), # UTC+10:30 -> 12:00 UTC = 22:30 ACDT (summer time)
|
||||
]
|
||||
|
||||
expression = "0 12 * * *" # Noon UTC
|
||||
|
||||
for timezone, expected_hour, expected_minute in timezones_with_offsets:
|
||||
with self.subTest(timezone=timezone):
|
||||
try:
|
||||
next_time = calculate_next_run_at(expression, timezone, self.base_time)
|
||||
assert next_time is not None
|
||||
|
||||
tz = pytz.timezone(timezone)
|
||||
local_time = next_time.astimezone(tz)
|
||||
assert local_time.hour == expected_hour
|
||||
assert local_time.minute == expected_minute
|
||||
except Exception:
|
||||
# Some complex timezone calculations might vary
|
||||
pass
|
||||
|
||||
def test_invalid_timezone_handling(self):
|
||||
"""Test handling of invalid timezones."""
|
||||
expression = "0 12 * * *"
|
||||
invalid_timezone = "Invalid/Timezone"
|
||||
|
||||
with pytest.raises((ValueError, Exception)): # Should raise an exception
|
||||
calculate_next_run_at(expression, invalid_timezone, self.base_time)
|
||||
|
||||
|
||||
class TestFrontendBackendIntegration(unittest.TestCase):
|
||||
"""Test integration patterns that mirror frontend usage."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)
|
||||
|
||||
def test_execution_time_calculator_pattern(self):
|
||||
"""Test the pattern used by execution-time-calculator.ts."""
|
||||
# This mirrors the exact usage from execution-time-calculator.ts:47
|
||||
test_data = {
|
||||
"cron_expression": "30 14 * * 1-5", # 2:30 PM weekdays
|
||||
"timezone": "America/New_York",
|
||||
}
|
||||
|
||||
# Get next 5 execution times (like the frontend does)
|
||||
execution_times = []
|
||||
current_base = self.base_time
|
||||
|
||||
for _ in range(5):
|
||||
next_time = calculate_next_run_at(test_data["cron_expression"], test_data["timezone"], current_base)
|
||||
assert next_time is not None
|
||||
execution_times.append(next_time)
|
||||
current_base = next_time + timedelta(seconds=1) # Move slightly forward
|
||||
|
||||
assert len(execution_times) == 5
|
||||
|
||||
# Validate each execution time
|
||||
for exec_time in execution_times:
|
||||
# Convert to local timezone
|
||||
tz = pytz.timezone(test_data["timezone"])
|
||||
local_time = exec_time.astimezone(tz)
|
||||
|
||||
# Should be weekdays (1-5)
|
||||
assert local_time.weekday() in [0, 1, 2, 3, 4] # Mon-Fri in Python weekday
|
||||
|
||||
# Should be 2:30 PM in local time
|
||||
assert local_time.hour == 14
|
||||
assert local_time.minute == 30
|
||||
assert local_time.second == 0
|
||||
|
||||
def test_schedule_service_integration(self):
|
||||
"""Test integration with ScheduleService patterns."""
|
||||
from core.workflow.nodes.trigger_schedule.entities import VisualConfig
|
||||
from services.trigger.schedule_service import ScheduleService
|
||||
|
||||
# Test enhanced syntax through visual config conversion
|
||||
visual_configs = [
|
||||
# Test with month abbreviations
|
||||
{
|
||||
"frequency": "monthly",
|
||||
"config": VisualConfig(time="9:00 AM", monthly_days=[1]),
|
||||
"expected_cron": "0 9 1 * *",
|
||||
},
|
||||
# Test with weekday abbreviations
|
||||
{
|
||||
"frequency": "weekly",
|
||||
"config": VisualConfig(time="2:30 PM", weekdays=["mon", "wed", "fri"]),
|
||||
"expected_cron": "30 14 * * 1,3,5",
|
||||
},
|
||||
]
|
||||
|
||||
for test_case in visual_configs:
|
||||
with self.subTest(frequency=test_case["frequency"]):
|
||||
cron_expr = ScheduleService.visual_to_cron(test_case["frequency"], test_case["config"])
|
||||
assert cron_expr == test_case["expected_cron"]
|
||||
|
||||
# Verify the generated cron expression is valid
|
||||
next_time = calculate_next_run_at(cron_expr, "UTC", self.base_time)
|
||||
assert next_time is not None
|
||||
|
||||
def test_error_handling_consistency(self):
|
||||
"""Test that error handling matches frontend expectations."""
|
||||
invalid_expressions = [
|
||||
"60 10 1 * *", # Invalid minute
|
||||
"15 25 1 * *", # Invalid hour
|
||||
"15 10 32 * *", # Invalid day
|
||||
"15 10 1 13 *", # Invalid month
|
||||
"15 10 1", # Too few fields
|
||||
"15 10 1 * * *", # 6 fields (not supported in frontend)
|
||||
"0 15 10 1 * * *", # 7 fields (not supported in frontend)
|
||||
"invalid expression", # Completely invalid
|
||||
]
|
||||
|
||||
for expr in invalid_expressions:
|
||||
with self.subTest(expr=repr(expr)):
|
||||
with pytest.raises((CroniterBadCronError, ValueError, Exception)):
|
||||
calculate_next_run_at(expr, "UTC", self.base_time)
|
||||
|
||||
# Note: Empty/whitespace expressions are not tested here as they are
|
||||
# not expected in normal usage due to database constraints (nullable=False)
|
||||
|
||||
def test_performance_requirements(self):
|
||||
"""Test that complex expressions parse within reasonable time."""
|
||||
import time
|
||||
|
||||
complex_expressions = [
|
||||
"*/5 9-17 * * 1-5", # Every 5 minutes, weekdays, business hours
|
||||
"0 */2 1,15 * *", # Every 2 hours on 1st and 15th
|
||||
"30 14 * * 1,3,5", # Mon, Wed, Fri at 14:30
|
||||
"15,45 8-18 * * 1-5", # 15 and 45 minutes past hour, weekdays
|
||||
"0 9 * JAN-MAR MON-FRI", # Enhanced syntax: Q1 weekdays at 9 AM
|
||||
"0 12 ? * SUN", # Enhanced syntax: Sundays at noon with ?
|
||||
]
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
for expr in complex_expressions:
|
||||
with self.subTest(expr=expr):
|
||||
try:
|
||||
next_time = calculate_next_run_at(expr, "UTC", self.base_time)
|
||||
assert next_time is not None
|
||||
except CroniterBadCronError:
|
||||
# Some enhanced syntax might not be supported, that's OK
|
||||
pass
|
||||
|
||||
end_time = time.time()
|
||||
execution_time = (end_time - start_time) * 1000 # Convert to milliseconds
|
||||
|
||||
# Should complete within reasonable time (less than 150ms like frontend)
|
||||
assert execution_time < 150, "Complex expressions should parse quickly"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Import timedelta for the test
|
||||
from datetime import timedelta
|
||||
|
||||
unittest.main()
|
||||
68
dify/api/tests/unit_tests/libs/test_custom_inputs.py
Normal file
68
dify/api/tests/unit_tests/libs/test_custom_inputs.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Unit tests for custom input types."""
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.custom_inputs import time_duration
|
||||
|
||||
|
||||
class TestTimeDuration:
|
||||
"""Test time_duration input validator."""
|
||||
|
||||
def test_valid_days(self):
|
||||
"""Test valid days format."""
|
||||
result = time_duration("7d")
|
||||
assert result == "7d"
|
||||
|
||||
def test_valid_hours(self):
|
||||
"""Test valid hours format."""
|
||||
result = time_duration("4h")
|
||||
assert result == "4h"
|
||||
|
||||
def test_valid_minutes(self):
|
||||
"""Test valid minutes format."""
|
||||
result = time_duration("30m")
|
||||
assert result == "30m"
|
||||
|
||||
def test_valid_seconds(self):
|
||||
"""Test valid seconds format."""
|
||||
result = time_duration("30s")
|
||||
assert result == "30s"
|
||||
|
||||
def test_uppercase_conversion(self):
|
||||
"""Test uppercase units are converted to lowercase."""
|
||||
result = time_duration("7D")
|
||||
assert result == "7d"
|
||||
|
||||
result = time_duration("4H")
|
||||
assert result == "4h"
|
||||
|
||||
def test_invalid_format_no_unit(self):
|
||||
"""Test invalid format without unit."""
|
||||
with pytest.raises(ValueError, match="Invalid time duration format"):
|
||||
time_duration("7")
|
||||
|
||||
def test_invalid_format_wrong_unit(self):
|
||||
"""Test invalid format with wrong unit."""
|
||||
with pytest.raises(ValueError, match="Invalid time duration format"):
|
||||
time_duration("7days")
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid time duration format"):
|
||||
time_duration("7x")
|
||||
|
||||
def test_invalid_format_no_number(self):
|
||||
"""Test invalid format without number."""
|
||||
with pytest.raises(ValueError, match="Invalid time duration format"):
|
||||
time_duration("d")
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid time duration format"):
|
||||
time_duration("abc")
|
||||
|
||||
def test_empty_string(self):
|
||||
"""Test empty string."""
|
||||
with pytest.raises(ValueError, match="Time duration cannot be empty"):
|
||||
time_duration("")
|
||||
|
||||
def test_none(self):
|
||||
"""Test None value."""
|
||||
with pytest.raises(ValueError, match="Time duration cannot be empty"):
|
||||
time_duration(None)
|
||||
268
dify/api/tests/unit_tests/libs/test_datetime_utils.py
Normal file
268
dify/api/tests/unit_tests/libs/test_datetime_utils.py
Normal file
@@ -0,0 +1,268 @@
|
||||
import datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytz
|
||||
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
|
||||
|
||||
def test_naive_utc_now(monkeypatch: pytest.MonkeyPatch):
|
||||
tz_aware_utc_now = datetime.datetime.now(tz=datetime.UTC)
|
||||
|
||||
def _now_func(tz: datetime.timezone | None) -> datetime.datetime:
|
||||
return tz_aware_utc_now.astimezone(tz)
|
||||
|
||||
monkeypatch.setattr("libs.datetime_utils._now_func", _now_func)
|
||||
|
||||
naive_datetime = naive_utc_now()
|
||||
|
||||
assert naive_datetime.tzinfo is None
|
||||
assert naive_datetime.date() == tz_aware_utc_now.date()
|
||||
naive_time = naive_datetime.time()
|
||||
utc_time = tz_aware_utc_now.time()
|
||||
assert naive_time == utc_time
|
||||
|
||||
|
||||
class TestParseTimeRange:
|
||||
"""Test cases for parse_time_range function."""
|
||||
|
||||
def test_parse_time_range_basic(self):
|
||||
"""Test basic time range parsing."""
|
||||
start, end = parse_time_range("2024-01-01 10:00", "2024-01-01 18:00", "UTC")
|
||||
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
assert start < end
|
||||
assert start.tzinfo == pytz.UTC
|
||||
assert end.tzinfo == pytz.UTC
|
||||
|
||||
def test_parse_time_range_start_only(self):
|
||||
"""Test parsing with only start time."""
|
||||
start, end = parse_time_range("2024-01-01 10:00", None, "UTC")
|
||||
|
||||
assert start is not None
|
||||
assert end is None
|
||||
assert start.tzinfo == pytz.UTC
|
||||
|
||||
def test_parse_time_range_end_only(self):
|
||||
"""Test parsing with only end time."""
|
||||
start, end = parse_time_range(None, "2024-01-01 18:00", "UTC")
|
||||
|
||||
assert start is None
|
||||
assert end is not None
|
||||
assert end.tzinfo == pytz.UTC
|
||||
|
||||
def test_parse_time_range_both_none(self):
|
||||
"""Test parsing with both times None."""
|
||||
start, end = parse_time_range(None, None, "UTC")
|
||||
|
||||
assert start is None
|
||||
assert end is None
|
||||
|
||||
def test_parse_time_range_different_timezones(self):
|
||||
"""Test parsing with different timezones."""
|
||||
# Test with US/Eastern timezone
|
||||
start, end = parse_time_range("2024-01-01 10:00", "2024-01-01 18:00", "US/Eastern")
|
||||
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
assert start.tzinfo == pytz.UTC
|
||||
assert end.tzinfo == pytz.UTC
|
||||
# Verify the times are correctly converted to UTC
|
||||
assert start.hour == 15 # 10 AM EST = 3 PM UTC (in January)
|
||||
assert end.hour == 23 # 6 PM EST = 11 PM UTC (in January)
|
||||
|
||||
def test_parse_time_range_invalid_start_format(self):
|
||||
"""Test parsing with invalid start time format."""
|
||||
with pytest.raises(ValueError, match="time data.*does not match format"):
|
||||
parse_time_range("invalid-date", "2024-01-01 18:00", "UTC")
|
||||
|
||||
def test_parse_time_range_invalid_end_format(self):
|
||||
"""Test parsing with invalid end time format."""
|
||||
with pytest.raises(ValueError, match="time data.*does not match format"):
|
||||
parse_time_range("2024-01-01 10:00", "invalid-date", "UTC")
|
||||
|
||||
def test_parse_time_range_invalid_timezone(self):
|
||||
"""Test parsing with invalid timezone."""
|
||||
with pytest.raises(pytz.exceptions.UnknownTimeZoneError):
|
||||
parse_time_range("2024-01-01 10:00", "2024-01-01 18:00", "Invalid/Timezone")
|
||||
|
||||
def test_parse_time_range_start_after_end(self):
|
||||
"""Test parsing with start time after end time."""
|
||||
with pytest.raises(ValueError, match="start must be earlier than or equal to end"):
|
||||
parse_time_range("2024-01-01 18:00", "2024-01-01 10:00", "UTC")
|
||||
|
||||
def test_parse_time_range_start_equals_end(self):
|
||||
"""Test parsing with start time equal to end time."""
|
||||
start, end = parse_time_range("2024-01-01 10:00", "2024-01-01 10:00", "UTC")
|
||||
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
assert start == end
|
||||
|
||||
def test_parse_time_range_dst_ambiguous_time(self):
|
||||
"""Test parsing during DST ambiguous time (fall back)."""
|
||||
# This test simulates DST fall back where 2:30 AM occurs twice
|
||||
with patch("pytz.timezone") as mock_timezone:
|
||||
# Mock timezone that raises AmbiguousTimeError
|
||||
mock_tz = mock_timezone.return_value
|
||||
|
||||
# Create a mock datetime object for the return value
|
||||
mock_dt = datetime.datetime(2024, 1, 1, 10, 0, 0)
|
||||
mock_utc_dt = mock_dt.replace(tzinfo=pytz.UTC)
|
||||
|
||||
# Create a proper mock for the localized datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_localized_dt = MagicMock()
|
||||
mock_localized_dt.astimezone.return_value = mock_utc_dt
|
||||
|
||||
# Set up side effects: first call raises exception, second call succeeds
|
||||
mock_tz.localize.side_effect = [
|
||||
pytz.AmbiguousTimeError("Ambiguous time"), # First call for start
|
||||
mock_localized_dt, # Second call for start (with is_dst=False)
|
||||
pytz.AmbiguousTimeError("Ambiguous time"), # First call for end
|
||||
mock_localized_dt, # Second call for end (with is_dst=False)
|
||||
]
|
||||
|
||||
start, end = parse_time_range("2024-01-01 10:00", "2024-01-01 18:00", "US/Eastern")
|
||||
|
||||
# Should use is_dst=False for ambiguous times
|
||||
assert mock_tz.localize.call_count == 4 # 2 calls per time (first fails, second succeeds)
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
|
||||
def test_parse_time_range_dst_nonexistent_time(self):
|
||||
"""Test parsing during DST nonexistent time (spring forward)."""
|
||||
with patch("pytz.timezone") as mock_timezone:
|
||||
# Mock timezone that raises NonExistentTimeError
|
||||
mock_tz = mock_timezone.return_value
|
||||
|
||||
# Create a mock datetime object for the return value
|
||||
mock_dt = datetime.datetime(2024, 1, 1, 10, 0, 0)
|
||||
mock_utc_dt = mock_dt.replace(tzinfo=pytz.UTC)
|
||||
|
||||
# Create a proper mock for the localized datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_localized_dt = MagicMock()
|
||||
mock_localized_dt.astimezone.return_value = mock_utc_dt
|
||||
|
||||
# Set up side effects: first call raises exception, second call succeeds
|
||||
mock_tz.localize.side_effect = [
|
||||
pytz.NonExistentTimeError("Non-existent time"), # First call for start
|
||||
mock_localized_dt, # Second call for start (with adjusted time)
|
||||
pytz.NonExistentTimeError("Non-existent time"), # First call for end
|
||||
mock_localized_dt, # Second call for end (with adjusted time)
|
||||
]
|
||||
|
||||
start, end = parse_time_range("2024-01-01 10:00", "2024-01-01 18:00", "US/Eastern")
|
||||
|
||||
# Should adjust time forward by 1 hour for nonexistent times
|
||||
assert mock_tz.localize.call_count == 4 # 2 calls per time (first fails, second succeeds)
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
|
||||
def test_parse_time_range_edge_cases(self):
|
||||
"""Test edge cases for time parsing."""
|
||||
# Test with midnight times
|
||||
start, end = parse_time_range("2024-01-01 00:00", "2024-01-01 23:59", "UTC")
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
assert start.hour == 0
|
||||
assert start.minute == 0
|
||||
assert end.hour == 23
|
||||
assert end.minute == 59
|
||||
|
||||
def test_parse_time_range_different_dates(self):
|
||||
"""Test parsing with different dates."""
|
||||
start, end = parse_time_range("2024-01-01 10:00", "2024-01-02 10:00", "UTC")
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
assert start.date() != end.date()
|
||||
assert (end - start).days == 1
|
||||
|
||||
def test_parse_time_range_seconds_handling(self):
|
||||
"""Test that seconds are properly set to 0."""
|
||||
start, end = parse_time_range("2024-01-01 10:30", "2024-01-01 18:45", "UTC")
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
assert start.second == 0
|
||||
assert end.second == 0
|
||||
|
||||
def test_parse_time_range_timezone_conversion_accuracy(self):
|
||||
"""Test accurate timezone conversion."""
|
||||
# Test with a known timezone conversion
|
||||
start, end = parse_time_range("2024-01-01 12:00", "2024-01-01 12:00", "Asia/Tokyo")
|
||||
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
assert start.tzinfo == pytz.UTC
|
||||
assert end.tzinfo == pytz.UTC
|
||||
# Tokyo is UTC+9, so 12:00 JST = 03:00 UTC
|
||||
assert start.hour == 3
|
||||
assert end.hour == 3
|
||||
|
||||
def test_parse_time_range_summer_time(self):
|
||||
"""Test parsing during summer time (DST)."""
|
||||
# Test with US/Eastern during summer (EDT = UTC-4)
|
||||
start, end = parse_time_range("2024-07-01 12:00", "2024-07-01 12:00", "US/Eastern")
|
||||
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
assert start.tzinfo == pytz.UTC
|
||||
assert end.tzinfo == pytz.UTC
|
||||
# 12:00 EDT = 16:00 UTC
|
||||
assert start.hour == 16
|
||||
assert end.hour == 16
|
||||
|
||||
def test_parse_time_range_winter_time(self):
|
||||
"""Test parsing during winter time (standard time)."""
|
||||
# Test with US/Eastern during winter (EST = UTC-5)
|
||||
start, end = parse_time_range("2024-01-01 12:00", "2024-01-01 12:00", "US/Eastern")
|
||||
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
assert start.tzinfo == pytz.UTC
|
||||
assert end.tzinfo == pytz.UTC
|
||||
# 12:00 EST = 17:00 UTC
|
||||
assert start.hour == 17
|
||||
assert end.hour == 17
|
||||
|
||||
def test_parse_time_range_empty_strings(self):
|
||||
"""Test parsing with empty strings."""
|
||||
# Empty strings are treated as None, so they should not raise errors
|
||||
start, end = parse_time_range("", "2024-01-01 18:00", "UTC")
|
||||
assert start is None
|
||||
assert end is not None
|
||||
|
||||
start, end = parse_time_range("2024-01-01 10:00", "", "UTC")
|
||||
assert start is not None
|
||||
assert end is None
|
||||
|
||||
def test_parse_time_range_malformed_datetime(self):
|
||||
"""Test parsing with malformed datetime strings."""
|
||||
with pytest.raises(ValueError, match="time data.*does not match format"):
|
||||
parse_time_range("2024-13-01 10:00", "2024-01-01 18:00", "UTC")
|
||||
|
||||
with pytest.raises(ValueError, match="time data.*does not match format"):
|
||||
parse_time_range("2024-01-01 10:00", "2024-01-32 18:00", "UTC")
|
||||
|
||||
def test_parse_time_range_very_long_time_range(self):
|
||||
"""Test parsing with very long time range."""
|
||||
start, end = parse_time_range("2020-01-01 00:00", "2030-12-31 23:59", "UTC")
|
||||
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
assert start < end
|
||||
assert (end - start).days > 3000 # More than 8 years
|
||||
|
||||
def test_parse_time_range_negative_timezone(self):
|
||||
"""Test parsing with negative timezone offset."""
|
||||
start, end = parse_time_range("2024-01-01 12:00", "2024-01-01 12:00", "America/New_York")
|
||||
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
assert start.tzinfo == pytz.UTC
|
||||
assert end.tzinfo == pytz.UTC
|
||||
21
dify/api/tests/unit_tests/libs/test_email.py
Normal file
21
dify/api/tests/unit_tests/libs/test_email.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import pytest
|
||||
|
||||
from libs.helper import email
|
||||
|
||||
|
||||
def test_email_with_valid_email():
|
||||
assert email("test@example.com") == "test@example.com"
|
||||
assert email("TEST12345@example.com") == "TEST12345@example.com"
|
||||
assert email("test+test@example.com") == "test+test@example.com"
|
||||
assert email("!#$%&'*+-/=?^_{|}~`@example.com") == "!#$%&'*+-/=?^_{|}~`@example.com"
|
||||
|
||||
|
||||
def test_email_with_invalid_email():
|
||||
with pytest.raises(ValueError, match="invalid_email is not a valid email."):
|
||||
email("invalid_email")
|
||||
|
||||
with pytest.raises(ValueError, match="@example.com is not a valid email."):
|
||||
email("@example.com")
|
||||
|
||||
with pytest.raises(ValueError, match="()@example.com is not a valid email."):
|
||||
email("()@example.com")
|
||||
576
dify/api/tests/unit_tests/libs/test_email_i18n.py
Normal file
576
dify/api/tests/unit_tests/libs/test_email_i18n.py
Normal file
@@ -0,0 +1,576 @@
|
||||
"""
|
||||
Unit tests for EmailI18nService
|
||||
|
||||
Tests the email internationalization service with mocked dependencies
|
||||
following Domain-Driven Design principles.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.email_i18n import (
|
||||
EmailI18nConfig,
|
||||
EmailI18nService,
|
||||
EmailLanguage,
|
||||
EmailTemplate,
|
||||
EmailType,
|
||||
FlaskEmailRenderer,
|
||||
FlaskMailSender,
|
||||
create_default_email_config,
|
||||
get_email_i18n_service,
|
||||
)
|
||||
from services.feature_service import BrandingModel
|
||||
|
||||
|
||||
class MockEmailRenderer:
|
||||
"""Mock implementation of EmailRenderer protocol"""
|
||||
|
||||
def __init__(self):
|
||||
self.rendered_templates: list[tuple[str, dict[str, Any]]] = []
|
||||
|
||||
def render_template(self, template_path: str, **context: Any) -> str:
|
||||
"""Mock render_template that returns a formatted string"""
|
||||
self.rendered_templates.append((template_path, context))
|
||||
return f"<html>Rendered {template_path} with {context}</html>"
|
||||
|
||||
|
||||
class MockBrandingService:
|
||||
"""Mock implementation of BrandingService protocol"""
|
||||
|
||||
def __init__(self, enabled: bool = False, application_title: str = "Dify"):
|
||||
self.enabled = enabled
|
||||
self.application_title = application_title
|
||||
|
||||
def get_branding_config(self) -> BrandingModel:
|
||||
"""Return mock branding configuration"""
|
||||
branding_model = MagicMock(spec=BrandingModel)
|
||||
branding_model.enabled = self.enabled
|
||||
branding_model.application_title = self.application_title
|
||||
return branding_model
|
||||
|
||||
|
||||
class MockEmailSender:
|
||||
"""Mock implementation of EmailSender protocol"""
|
||||
|
||||
def __init__(self):
|
||||
self.sent_emails: list[dict[str, str]] = []
|
||||
|
||||
def send_email(self, to: str, subject: str, html_content: str):
|
||||
"""Mock send_email that records sent emails"""
|
||||
self.sent_emails.append(
|
||||
{
|
||||
"to": to,
|
||||
"subject": subject,
|
||||
"html_content": html_content,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestEmailI18nService:
|
||||
"""Test cases for EmailI18nService"""
|
||||
|
||||
@pytest.fixture
|
||||
def email_config(self) -> EmailI18nConfig:
|
||||
"""Create test email configuration"""
|
||||
return EmailI18nConfig(
|
||||
templates={
|
||||
EmailType.RESET_PASSWORD: {
|
||||
EmailLanguage.EN_US: EmailTemplate(
|
||||
subject="Reset Your {application_title} Password",
|
||||
template_path="reset_password_en.html",
|
||||
branded_template_path="branded/reset_password_en.html",
|
||||
),
|
||||
EmailLanguage.ZH_HANS: EmailTemplate(
|
||||
subject="重置您的 {application_title} 密码",
|
||||
template_path="reset_password_zh.html",
|
||||
branded_template_path="branded/reset_password_zh.html",
|
||||
),
|
||||
},
|
||||
EmailType.INVITE_MEMBER: {
|
||||
EmailLanguage.EN_US: EmailTemplate(
|
||||
subject="Join {application_title} Workspace",
|
||||
template_path="invite_member_en.html",
|
||||
branded_template_path="branded/invite_member_en.html",
|
||||
),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_renderer(self) -> MockEmailRenderer:
|
||||
"""Create mock email renderer"""
|
||||
return MockEmailRenderer()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_branding_service(self) -> MockBrandingService:
|
||||
"""Create mock branding service"""
|
||||
return MockBrandingService()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sender(self) -> MockEmailSender:
|
||||
"""Create mock email sender"""
|
||||
return MockEmailSender()
|
||||
|
||||
@pytest.fixture
|
||||
def email_service(
|
||||
self,
|
||||
email_config: EmailI18nConfig,
|
||||
mock_renderer: MockEmailRenderer,
|
||||
mock_branding_service: MockBrandingService,
|
||||
mock_sender: MockEmailSender,
|
||||
) -> EmailI18nService:
|
||||
"""Create EmailI18nService with mocked dependencies"""
|
||||
return EmailI18nService(
|
||||
config=email_config,
|
||||
renderer=mock_renderer,
|
||||
branding_service=mock_branding_service,
|
||||
sender=mock_sender,
|
||||
)
|
||||
|
||||
def test_send_email_with_english_language(
|
||||
self,
|
||||
email_service: EmailI18nService,
|
||||
mock_renderer: MockEmailRenderer,
|
||||
mock_sender: MockEmailSender,
|
||||
):
|
||||
"""Test sending email with English language"""
|
||||
email_service.send_email(
|
||||
email_type=EmailType.RESET_PASSWORD,
|
||||
language_code="en-US",
|
||||
to="test@example.com",
|
||||
template_context={"reset_link": "https://example.com/reset"},
|
||||
)
|
||||
|
||||
# Verify renderer was called with correct template
|
||||
assert len(mock_renderer.rendered_templates) == 1
|
||||
template_path, context = mock_renderer.rendered_templates[0]
|
||||
assert template_path == "reset_password_en.html"
|
||||
assert context["reset_link"] == "https://example.com/reset"
|
||||
assert context["branding_enabled"] is False
|
||||
assert context["application_title"] == "Dify"
|
||||
|
||||
# Verify email was sent
|
||||
assert len(mock_sender.sent_emails) == 1
|
||||
sent_email = mock_sender.sent_emails[0]
|
||||
assert sent_email["to"] == "test@example.com"
|
||||
assert sent_email["subject"] == "Reset Your Dify Password"
|
||||
assert "reset_password_en.html" in sent_email["html_content"]
|
||||
|
||||
def test_send_email_with_chinese_language(
|
||||
self,
|
||||
email_service: EmailI18nService,
|
||||
mock_sender: MockEmailSender,
|
||||
):
|
||||
"""Test sending email with Chinese language"""
|
||||
email_service.send_email(
|
||||
email_type=EmailType.RESET_PASSWORD,
|
||||
language_code="zh-Hans",
|
||||
to="test@example.com",
|
||||
template_context={"reset_link": "https://example.com/reset"},
|
||||
)
|
||||
|
||||
# Verify email was sent with Chinese subject
|
||||
assert len(mock_sender.sent_emails) == 1
|
||||
sent_email = mock_sender.sent_emails[0]
|
||||
assert sent_email["subject"] == "重置您的 Dify 密码"
|
||||
|
||||
def test_send_email_with_branding_enabled(
|
||||
self,
|
||||
email_config: EmailI18nConfig,
|
||||
mock_renderer: MockEmailRenderer,
|
||||
mock_sender: MockEmailSender,
|
||||
):
|
||||
"""Test sending email with branding enabled"""
|
||||
# Create branding service with branding enabled
|
||||
branding_service = MockBrandingService(enabled=True, application_title="MyApp")
|
||||
|
||||
email_service = EmailI18nService(
|
||||
config=email_config,
|
||||
renderer=mock_renderer,
|
||||
branding_service=branding_service,
|
||||
sender=mock_sender,
|
||||
)
|
||||
|
||||
email_service.send_email(
|
||||
email_type=EmailType.RESET_PASSWORD,
|
||||
language_code="en-US",
|
||||
to="test@example.com",
|
||||
)
|
||||
|
||||
# Verify branded template was used
|
||||
assert len(mock_renderer.rendered_templates) == 1
|
||||
template_path, context = mock_renderer.rendered_templates[0]
|
||||
assert template_path == "branded/reset_password_en.html"
|
||||
assert context["branding_enabled"] is True
|
||||
assert context["application_title"] == "MyApp"
|
||||
|
||||
# Verify subject includes custom application title
|
||||
assert len(mock_sender.sent_emails) == 1
|
||||
sent_email = mock_sender.sent_emails[0]
|
||||
assert sent_email["subject"] == "Reset Your MyApp Password"
|
||||
|
||||
def test_send_email_with_language_fallback(
|
||||
self,
|
||||
email_service: EmailI18nService,
|
||||
mock_sender: MockEmailSender,
|
||||
):
|
||||
"""Test language fallback to English when requested language not available"""
|
||||
# Request invite member in Chinese (not configured)
|
||||
email_service.send_email(
|
||||
email_type=EmailType.INVITE_MEMBER,
|
||||
language_code="zh-Hans",
|
||||
to="test@example.com",
|
||||
)
|
||||
|
||||
# Should fall back to English
|
||||
assert len(mock_sender.sent_emails) == 1
|
||||
sent_email = mock_sender.sent_emails[0]
|
||||
assert sent_email["subject"] == "Join Dify Workspace"
|
||||
|
||||
def test_send_email_with_unknown_language_code(
|
||||
self,
|
||||
email_service: EmailI18nService,
|
||||
mock_sender: MockEmailSender,
|
||||
):
|
||||
"""Test unknown language code falls back to English"""
|
||||
email_service.send_email(
|
||||
email_type=EmailType.RESET_PASSWORD,
|
||||
language_code="fr-FR", # French not configured
|
||||
to="test@example.com",
|
||||
)
|
||||
|
||||
# Should use English
|
||||
assert len(mock_sender.sent_emails) == 1
|
||||
sent_email = mock_sender.sent_emails[0]
|
||||
assert sent_email["subject"] == "Reset Your Dify Password"
|
||||
|
||||
def test_subject_format_keyerror_fallback_path(
|
||||
self,
|
||||
mock_renderer: MockEmailRenderer,
|
||||
mock_sender: MockEmailSender,
|
||||
):
|
||||
"""Trigger subject KeyError and cover except branch."""
|
||||
# Config with subject that references an unknown key (no {application_title} to avoid second format)
|
||||
config = EmailI18nConfig(
|
||||
templates={
|
||||
EmailType.INVITE_MEMBER: {
|
||||
EmailLanguage.EN_US: EmailTemplate(
|
||||
subject="Invite: {unknown_placeholder}",
|
||||
template_path="invite_member_en.html",
|
||||
branded_template_path="branded/invite_member_en.html",
|
||||
),
|
||||
}
|
||||
}
|
||||
)
|
||||
branding_service = MockBrandingService(enabled=False)
|
||||
service = EmailI18nService(
|
||||
config=config,
|
||||
renderer=mock_renderer,
|
||||
branding_service=branding_service,
|
||||
sender=mock_sender,
|
||||
)
|
||||
|
||||
# Will raise KeyError on subject.format(**full_context), then hit except branch and skip fallback
|
||||
service.send_email(
|
||||
email_type=EmailType.INVITE_MEMBER,
|
||||
language_code="en-US",
|
||||
to="test@example.com",
|
||||
)
|
||||
|
||||
assert len(mock_sender.sent_emails) == 1
|
||||
# Subject is left unformatted due to KeyError fallback path without application_title
|
||||
assert mock_sender.sent_emails[0]["subject"] == "Invite: {unknown_placeholder}"
|
||||
|
||||
def test_send_change_email_old_phase(
|
||||
self,
|
||||
email_config: EmailI18nConfig,
|
||||
mock_renderer: MockEmailRenderer,
|
||||
mock_sender: MockEmailSender,
|
||||
mock_branding_service: MockBrandingService,
|
||||
):
|
||||
"""Test sending change email for old email verification"""
|
||||
# Add change email templates to config
|
||||
email_config.templates[EmailType.CHANGE_EMAIL_OLD] = {
|
||||
EmailLanguage.EN_US: EmailTemplate(
|
||||
subject="Verify your current email",
|
||||
template_path="change_email_old_en.html",
|
||||
branded_template_path="branded/change_email_old_en.html",
|
||||
),
|
||||
}
|
||||
|
||||
email_service = EmailI18nService(
|
||||
config=email_config,
|
||||
renderer=mock_renderer,
|
||||
branding_service=mock_branding_service,
|
||||
sender=mock_sender,
|
||||
)
|
||||
|
||||
email_service.send_change_email(
|
||||
language_code="en-US",
|
||||
to="old@example.com",
|
||||
code="123456",
|
||||
phase="old_email",
|
||||
)
|
||||
|
||||
# Verify correct template and context
|
||||
assert len(mock_renderer.rendered_templates) == 1
|
||||
template_path, context = mock_renderer.rendered_templates[0]
|
||||
assert template_path == "change_email_old_en.html"
|
||||
assert context["to"] == "old@example.com"
|
||||
assert context["code"] == "123456"
|
||||
|
||||
def test_send_change_email_new_phase(
|
||||
self,
|
||||
email_config: EmailI18nConfig,
|
||||
mock_renderer: MockEmailRenderer,
|
||||
mock_sender: MockEmailSender,
|
||||
mock_branding_service: MockBrandingService,
|
||||
):
|
||||
"""Test sending change email for new email verification"""
|
||||
# Add change email templates to config
|
||||
email_config.templates[EmailType.CHANGE_EMAIL_NEW] = {
|
||||
EmailLanguage.EN_US: EmailTemplate(
|
||||
subject="Verify your new email",
|
||||
template_path="change_email_new_en.html",
|
||||
branded_template_path="branded/change_email_new_en.html",
|
||||
),
|
||||
}
|
||||
|
||||
email_service = EmailI18nService(
|
||||
config=email_config,
|
||||
renderer=mock_renderer,
|
||||
branding_service=mock_branding_service,
|
||||
sender=mock_sender,
|
||||
)
|
||||
|
||||
email_service.send_change_email(
|
||||
language_code="en-US",
|
||||
to="new@example.com",
|
||||
code="654321",
|
||||
phase="new_email",
|
||||
)
|
||||
|
||||
# Verify correct template and context
|
||||
assert len(mock_renderer.rendered_templates) == 1
|
||||
template_path, context = mock_renderer.rendered_templates[0]
|
||||
assert template_path == "change_email_new_en.html"
|
||||
assert context["to"] == "new@example.com"
|
||||
assert context["code"] == "654321"
|
||||
|
||||
def test_send_change_email_invalid_phase(
|
||||
self,
|
||||
email_service: EmailI18nService,
|
||||
):
|
||||
"""Test sending change email with invalid phase raises error"""
|
||||
with pytest.raises(ValueError, match="Invalid phase: invalid_phase"):
|
||||
email_service.send_change_email(
|
||||
language_code="en-US",
|
||||
to="test@example.com",
|
||||
code="123456",
|
||||
phase="invalid_phase",
|
||||
)
|
||||
|
||||
def test_send_raw_email_single_recipient(
|
||||
self,
|
||||
email_service: EmailI18nService,
|
||||
mock_sender: MockEmailSender,
|
||||
):
|
||||
"""Test sending raw email to single recipient"""
|
||||
email_service.send_raw_email(
|
||||
to="test@example.com",
|
||||
subject="Test Subject",
|
||||
html_content="<html>Test Content</html>",
|
||||
)
|
||||
|
||||
assert len(mock_sender.sent_emails) == 1
|
||||
sent_email = mock_sender.sent_emails[0]
|
||||
assert sent_email["to"] == "test@example.com"
|
||||
assert sent_email["subject"] == "Test Subject"
|
||||
assert sent_email["html_content"] == "<html>Test Content</html>"
|
||||
|
||||
def test_send_raw_email_multiple_recipients(
|
||||
self,
|
||||
email_service: EmailI18nService,
|
||||
mock_sender: MockEmailSender,
|
||||
):
|
||||
"""Test sending raw email to multiple recipients"""
|
||||
recipients = ["user1@example.com", "user2@example.com", "user3@example.com"]
|
||||
|
||||
email_service.send_raw_email(
|
||||
to=recipients,
|
||||
subject="Test Subject",
|
||||
html_content="<html>Test Content</html>",
|
||||
)
|
||||
|
||||
# Should send individual emails to each recipient
|
||||
assert len(mock_sender.sent_emails) == 3
|
||||
for i, recipient in enumerate(recipients):
|
||||
sent_email = mock_sender.sent_emails[i]
|
||||
assert sent_email["to"] == recipient
|
||||
assert sent_email["subject"] == "Test Subject"
|
||||
assert sent_email["html_content"] == "<html>Test Content</html>"
|
||||
|
||||
def test_get_template_missing_email_type(
|
||||
self,
|
||||
email_config: EmailI18nConfig,
|
||||
):
|
||||
"""Test getting template for missing email type raises error"""
|
||||
with pytest.raises(ValueError, match="No templates configured for email type"):
|
||||
email_config.get_template(EmailType.EMAIL_CODE_LOGIN, EmailLanguage.EN_US)
|
||||
|
||||
def test_get_template_missing_language_and_english(
|
||||
self,
|
||||
email_config: EmailI18nConfig,
|
||||
):
|
||||
"""Test error when neither requested language nor English fallback exists"""
|
||||
# Add template without English fallback
|
||||
email_config.templates[EmailType.EMAIL_CODE_LOGIN] = {
|
||||
EmailLanguage.ZH_HANS: EmailTemplate(
|
||||
subject="Test",
|
||||
template_path="test.html",
|
||||
branded_template_path="branded/test.html",
|
||||
),
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="No template found for"):
|
||||
# Request a language that doesn't exist and no English fallback
|
||||
email_config.get_template(EmailType.EMAIL_CODE_LOGIN, EmailLanguage.EN_US)
|
||||
|
||||
def test_subject_templating_with_variables(
|
||||
self,
|
||||
email_config: EmailI18nConfig,
|
||||
mock_renderer: MockEmailRenderer,
|
||||
mock_sender: MockEmailSender,
|
||||
mock_branding_service: MockBrandingService,
|
||||
):
|
||||
"""Test subject templating with custom variables"""
|
||||
# Add template with variable in subject
|
||||
email_config.templates[EmailType.OWNER_TRANSFER_NEW_NOTIFY] = {
|
||||
EmailLanguage.EN_US: EmailTemplate(
|
||||
subject="You are now the owner of {WorkspaceName}",
|
||||
template_path="owner_transfer_en.html",
|
||||
branded_template_path="branded/owner_transfer_en.html",
|
||||
),
|
||||
}
|
||||
|
||||
email_service = EmailI18nService(
|
||||
config=email_config,
|
||||
renderer=mock_renderer,
|
||||
branding_service=mock_branding_service,
|
||||
sender=mock_sender,
|
||||
)
|
||||
|
||||
email_service.send_email(
|
||||
email_type=EmailType.OWNER_TRANSFER_NEW_NOTIFY,
|
||||
language_code="en-US",
|
||||
to="test@example.com",
|
||||
template_context={"WorkspaceName": "My Workspace"},
|
||||
)
|
||||
|
||||
# Verify subject was templated correctly
|
||||
assert len(mock_sender.sent_emails) == 1
|
||||
sent_email = mock_sender.sent_emails[0]
|
||||
assert sent_email["subject"] == "You are now the owner of My Workspace"
|
||||
|
||||
def test_email_language_from_language_code(self):
|
||||
"""Test EmailLanguage.from_language_code method"""
|
||||
assert EmailLanguage.from_language_code("zh-Hans") == EmailLanguage.ZH_HANS
|
||||
assert EmailLanguage.from_language_code("en-US") == EmailLanguage.EN_US
|
||||
assert EmailLanguage.from_language_code("fr-FR") == EmailLanguage.EN_US # Fallback
|
||||
assert EmailLanguage.from_language_code("unknown") == EmailLanguage.EN_US # Fallback
|
||||
|
||||
|
||||
class TestEmailI18nIntegration:
|
||||
"""Integration tests for email i18n components"""
|
||||
|
||||
def test_create_default_email_config(self):
|
||||
"""Test creating default email configuration"""
|
||||
config = create_default_email_config()
|
||||
|
||||
# Verify key email types have at least English template
|
||||
expected_types = [
|
||||
EmailType.RESET_PASSWORD,
|
||||
EmailType.INVITE_MEMBER,
|
||||
EmailType.EMAIL_CODE_LOGIN,
|
||||
EmailType.CHANGE_EMAIL_OLD,
|
||||
EmailType.CHANGE_EMAIL_NEW,
|
||||
EmailType.OWNER_TRANSFER_CONFIRM,
|
||||
EmailType.OWNER_TRANSFER_OLD_NOTIFY,
|
||||
EmailType.OWNER_TRANSFER_NEW_NOTIFY,
|
||||
EmailType.ACCOUNT_DELETION_SUCCESS,
|
||||
EmailType.ACCOUNT_DELETION_VERIFICATION,
|
||||
EmailType.QUEUE_MONITOR_ALERT,
|
||||
EmailType.DOCUMENT_CLEAN_NOTIFY,
|
||||
]
|
||||
|
||||
for email_type in expected_types:
|
||||
assert email_type in config.templates
|
||||
assert EmailLanguage.EN_US in config.templates[email_type]
|
||||
|
||||
# Verify some have Chinese translations
|
||||
assert EmailLanguage.ZH_HANS in config.templates[EmailType.RESET_PASSWORD]
|
||||
assert EmailLanguage.ZH_HANS in config.templates[EmailType.INVITE_MEMBER]
|
||||
|
||||
def test_get_email_i18n_service(self):
|
||||
"""Test getting global email i18n service instance"""
|
||||
service1 = get_email_i18n_service()
|
||||
service2 = get_email_i18n_service()
|
||||
|
||||
# Should return the same instance
|
||||
assert service1 is service2
|
||||
|
||||
def test_flask_email_renderer(self):
|
||||
"""Test FlaskEmailRenderer implementation"""
|
||||
renderer = FlaskEmailRenderer()
|
||||
|
||||
# Should raise TemplateNotFound when template doesn't exist
|
||||
from jinja2.exceptions import TemplateNotFound
|
||||
|
||||
with pytest.raises(TemplateNotFound):
|
||||
renderer.render_template("test.html", foo="bar")
|
||||
|
||||
def test_flask_mail_sender_not_initialized(self):
|
||||
"""Test FlaskMailSender when mail is not initialized"""
|
||||
sender = FlaskMailSender()
|
||||
|
||||
# Mock mail.is_inited() to return False
|
||||
import libs.email_i18n
|
||||
|
||||
original_mail = libs.email_i18n.mail
|
||||
mock_mail = MagicMock()
|
||||
mock_mail.is_inited.return_value = False
|
||||
libs.email_i18n.mail = mock_mail
|
||||
|
||||
try:
|
||||
# Should not send email when mail is not initialized
|
||||
sender.send_email("test@example.com", "Subject", "<html>Content</html>")
|
||||
mock_mail.send.assert_not_called()
|
||||
finally:
|
||||
# Restore original mail
|
||||
libs.email_i18n.mail = original_mail
|
||||
|
||||
def test_flask_mail_sender_initialized(self):
|
||||
"""Test FlaskMailSender when mail is initialized"""
|
||||
sender = FlaskMailSender()
|
||||
|
||||
# Mock mail.is_inited() to return True
|
||||
import libs.email_i18n
|
||||
|
||||
original_mail = libs.email_i18n.mail
|
||||
mock_mail = MagicMock()
|
||||
mock_mail.is_inited.return_value = True
|
||||
libs.email_i18n.mail = mock_mail
|
||||
|
||||
try:
|
||||
# Should send email when mail is initialized
|
||||
sender.send_email("test@example.com", "Subject", "<html>Content</html>")
|
||||
mock_mail.send.assert_called_once_with(
|
||||
to="test@example.com",
|
||||
subject="Subject",
|
||||
html="<html>Content</html>",
|
||||
)
|
||||
finally:
|
||||
# Restore original mail
|
||||
libs.email_i18n.mail = original_mail
|
||||
187
dify/api/tests/unit_tests/libs/test_external_api.py
Normal file
187
dify/api/tests/unit_tests/libs/test_external_api.py
Normal file
@@ -0,0 +1,187 @@
|
||||
from flask import Blueprint, Flask
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest, Unauthorized
|
||||
|
||||
from constants import COOKIE_NAME_ACCESS_TOKEN, COOKIE_NAME_CSRF_TOKEN, COOKIE_NAME_REFRESH_TOKEN
|
||||
from core.errors.error import AppInvokeQuotaExceededError
|
||||
from libs.exception import BaseHTTPException
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
|
||||
def _create_api_app():
|
||||
app = Flask(__name__)
|
||||
bp = Blueprint("t", __name__)
|
||||
api = ExternalApi(bp)
|
||||
|
||||
@api.route("/bad-request")
|
||||
class Bad(Resource):
|
||||
def get(self):
|
||||
raise BadRequest("invalid input")
|
||||
|
||||
@api.route("/unauth")
|
||||
class Unauth(Resource):
|
||||
def get(self):
|
||||
raise Unauthorized("auth required")
|
||||
|
||||
@api.route("/value-error")
|
||||
class ValErr(Resource):
|
||||
def get(self):
|
||||
raise ValueError("boom")
|
||||
|
||||
@api.route("/quota")
|
||||
class Quota(Resource):
|
||||
def get(self):
|
||||
raise AppInvokeQuotaExceededError("quota exceeded")
|
||||
|
||||
@api.route("/general")
|
||||
class Gen(Resource):
|
||||
def get(self):
|
||||
raise RuntimeError("oops")
|
||||
|
||||
# Note: We avoid altering default_mediatype to keep normal error paths
|
||||
|
||||
# Special 400 message rewrite
|
||||
@api.route("/json-empty")
|
||||
class JsonEmpty(Resource):
|
||||
def get(self):
|
||||
e = BadRequest()
|
||||
# Force the specific message the handler rewrites
|
||||
e.description = "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)"
|
||||
raise e
|
||||
|
||||
# 400 mapping payload path
|
||||
@api.route("/param-errors")
|
||||
class ParamErrors(Resource):
|
||||
def get(self):
|
||||
e = BadRequest()
|
||||
# Coerce a mapping description to trigger param error shaping
|
||||
e.description = {"field": "is required"}
|
||||
raise e
|
||||
|
||||
app.register_blueprint(bp, url_prefix="/api")
|
||||
return app
|
||||
|
||||
|
||||
def test_external_api_error_handlers_basic_paths():
|
||||
app = _create_api_app()
|
||||
client = app.test_client()
|
||||
|
||||
# 400
|
||||
res = client.get("/api/bad-request")
|
||||
assert res.status_code == 400
|
||||
data = res.get_json()
|
||||
assert data["code"] == "bad_request"
|
||||
assert data["status"] == 400
|
||||
|
||||
# 401
|
||||
res = client.get("/api/unauth")
|
||||
assert res.status_code == 401
|
||||
assert "WWW-Authenticate" in res.headers
|
||||
|
||||
# 400 ValueError
|
||||
res = client.get("/api/value-error")
|
||||
assert res.status_code == 400
|
||||
assert res.get_json()["code"] == "invalid_param"
|
||||
|
||||
# 500 general
|
||||
res = client.get("/api/general")
|
||||
assert res.status_code == 500
|
||||
assert res.get_json()["status"] == 500
|
||||
|
||||
|
||||
def test_external_api_json_message_and_bad_request_rewrite():
|
||||
app = _create_api_app()
|
||||
client = app.test_client()
|
||||
|
||||
# JSON empty special rewrite
|
||||
res = client.get("/api/json-empty")
|
||||
assert res.status_code == 400
|
||||
assert res.get_json()["message"] == "Invalid JSON payload received or JSON payload is empty."
|
||||
|
||||
|
||||
def test_external_api_param_mapping_and_quota_and_exc_info_none():
|
||||
# Force exc_info() to return (None,None,None) only during request
|
||||
import libs.external_api as ext
|
||||
|
||||
orig_exc_info = ext.sys.exc_info
|
||||
try:
|
||||
ext.sys.exc_info = lambda: (None, None, None)
|
||||
|
||||
app = _create_api_app()
|
||||
client = app.test_client()
|
||||
|
||||
# Param errors mapping payload path
|
||||
res = client.get("/api/param-errors")
|
||||
assert res.status_code == 400
|
||||
data = res.get_json()
|
||||
assert data["code"] == "invalid_param"
|
||||
assert data["params"] == "field"
|
||||
|
||||
# Quota path — depending on Flask-RESTX internals it may be handled
|
||||
res = client.get("/api/quota")
|
||||
assert res.status_code in (400, 429)
|
||||
finally:
|
||||
ext.sys.exc_info = orig_exc_info # type: ignore[assignment]
|
||||
|
||||
|
||||
def test_unauthorized_and_force_logout_clears_cookies():
|
||||
"""Test that UnauthorizedAndForceLogout error clears auth cookies"""
|
||||
|
||||
class UnauthorizedAndForceLogout(BaseHTTPException):
|
||||
error_code = "unauthorized_and_force_logout"
|
||||
description = "Unauthorized and force logout."
|
||||
code = 401
|
||||
|
||||
app = Flask(__name__)
|
||||
bp = Blueprint("test", __name__)
|
||||
api = ExternalApi(bp)
|
||||
|
||||
@api.route("/force-logout")
|
||||
class ForceLogout(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
raise UnauthorizedAndForceLogout()
|
||||
|
||||
app.register_blueprint(bp, url_prefix="/api")
|
||||
client = app.test_client()
|
||||
|
||||
# Set cookies first
|
||||
client.set_cookie(COOKIE_NAME_ACCESS_TOKEN, "test_access_token")
|
||||
client.set_cookie(COOKIE_NAME_CSRF_TOKEN, "test_csrf_token")
|
||||
client.set_cookie(COOKIE_NAME_REFRESH_TOKEN, "test_refresh_token")
|
||||
|
||||
# Make request that should trigger cookie clearing
|
||||
res = client.get("/api/force-logout")
|
||||
|
||||
# Verify response
|
||||
assert res.status_code == 401
|
||||
data = res.get_json()
|
||||
assert data["code"] == "unauthorized_and_force_logout"
|
||||
assert data["status"] == 401
|
||||
assert "WWW-Authenticate" in res.headers
|
||||
|
||||
# Verify Set-Cookie headers are present to clear cookies
|
||||
set_cookie_headers = res.headers.getlist("Set-Cookie")
|
||||
assert len(set_cookie_headers) == 3, f"Expected 3 Set-Cookie headers, got {len(set_cookie_headers)}"
|
||||
|
||||
# Verify each cookie is being cleared (empty value and expired)
|
||||
cookie_names_found = set()
|
||||
for cookie_header in set_cookie_headers:
|
||||
# Check for cookie names
|
||||
if COOKIE_NAME_ACCESS_TOKEN in cookie_header:
|
||||
cookie_names_found.add(COOKIE_NAME_ACCESS_TOKEN)
|
||||
assert '""' in cookie_header or "=" in cookie_header # Empty value
|
||||
assert "Expires=Thu, 01 Jan 1970" in cookie_header # Expired
|
||||
elif COOKIE_NAME_CSRF_TOKEN in cookie_header:
|
||||
cookie_names_found.add(COOKIE_NAME_CSRF_TOKEN)
|
||||
assert '""' in cookie_header or "=" in cookie_header
|
||||
assert "Expires=Thu, 01 Jan 1970" in cookie_header
|
||||
elif COOKIE_NAME_REFRESH_TOKEN in cookie_header:
|
||||
cookie_names_found.add(COOKIE_NAME_REFRESH_TOKEN)
|
||||
assert '""' in cookie_header or "=" in cookie_header
|
||||
assert "Expires=Thu, 01 Jan 1970" in cookie_header
|
||||
|
||||
# Verify all three cookies are present
|
||||
assert len(cookie_names_found) == 3
|
||||
assert COOKIE_NAME_ACCESS_TOKEN in cookie_names_found
|
||||
assert COOKIE_NAME_CSRF_TOKEN in cookie_names_found
|
||||
assert COOKIE_NAME_REFRESH_TOKEN in cookie_names_found
|
||||
55
dify/api/tests/unit_tests/libs/test_file_utils.py
Normal file
55
dify/api/tests/unit_tests/libs/test_file_utils.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.file_utils import search_file_upwards
|
||||
|
||||
|
||||
def test_search_file_upwards_found_in_parent(tmp_path: Path):
|
||||
base = tmp_path / "a" / "b" / "c"
|
||||
base.mkdir(parents=True)
|
||||
|
||||
target = tmp_path / "a" / "target.txt"
|
||||
target.write_text("ok", encoding="utf-8")
|
||||
|
||||
found = search_file_upwards(base, "target.txt", max_search_parent_depth=5)
|
||||
assert found == target
|
||||
|
||||
|
||||
def test_search_file_upwards_found_in_current(tmp_path: Path):
|
||||
base = tmp_path / "x"
|
||||
base.mkdir()
|
||||
target = base / "here.txt"
|
||||
target.write_text("x", encoding="utf-8")
|
||||
|
||||
found = search_file_upwards(base, "here.txt", max_search_parent_depth=1)
|
||||
assert found == target
|
||||
|
||||
|
||||
def test_search_file_upwards_not_found_raises(tmp_path: Path):
|
||||
base = tmp_path / "m" / "n"
|
||||
base.mkdir(parents=True)
|
||||
with pytest.raises(ValueError) as exc:
|
||||
search_file_upwards(base, "missing.txt", max_search_parent_depth=3)
|
||||
# error message should contain file name and base path
|
||||
msg = str(exc.value)
|
||||
assert "missing.txt" in msg
|
||||
assert str(base) in msg
|
||||
|
||||
|
||||
def test_search_file_upwards_root_breaks_and_raises():
|
||||
# Using filesystem root triggers the 'break' branch (parent == current)
|
||||
with pytest.raises(ValueError):
|
||||
search_file_upwards(Path("/"), "__definitely_not_exists__.txt", max_search_parent_depth=1)
|
||||
|
||||
|
||||
def test_search_file_upwards_depth_limit_raises(tmp_path: Path):
|
||||
base = tmp_path / "a" / "b" / "c"
|
||||
base.mkdir(parents=True)
|
||||
target = tmp_path / "a" / "target.txt"
|
||||
target.write_text("ok", encoding="utf-8")
|
||||
# The file is 2 levels up from `c` (in `a`), but search depth is only 2.
|
||||
# The search path is `c` (depth 1) -> `b` (depth 2). The file is in `a` (would need depth 3).
|
||||
# So, this should not find the file and should raise an error.
|
||||
with pytest.raises(ValueError):
|
||||
search_file_upwards(base, "target.txt", max_search_parent_depth=2)
|
||||
123
dify/api/tests/unit_tests/libs/test_flask_utils.py
Normal file
123
dify/api/tests/unit_tests/libs/test_flask_utils.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import contextvars
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_login import LoginManager, UserMixin, current_user, login_user
|
||||
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
|
||||
class User(UserMixin):
|
||||
"""Simple User class for testing."""
|
||||
|
||||
def __init__(self, id: str):
|
||||
self.id = id
|
||||
|
||||
def get_id(self) -> str:
|
||||
return self.id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def login_app(app: Flask) -> Flask:
|
||||
"""Set up a Flask app with flask-login."""
|
||||
# Set a secret key for the app
|
||||
app.config["SECRET_KEY"] = "test-secret-key"
|
||||
|
||||
login_manager = LoginManager()
|
||||
login_manager.init_app(app)
|
||||
|
||||
@login_manager.user_loader
|
||||
def load_user(user_id: str) -> User | None:
|
||||
if user_id == "test_user":
|
||||
return User("test_user")
|
||||
return None
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user() -> User:
|
||||
"""Create a test user."""
|
||||
return User("test_user")
|
||||
|
||||
|
||||
def test_current_user_not_accessible_across_threads(login_app: Flask, test_user: User):
|
||||
"""
|
||||
Test that current_user is not accessible in a different thread without preserve_flask_contexts.
|
||||
|
||||
This test demonstrates that without the preserve_flask_contexts, we cannot access
|
||||
current_user in a different thread, even with app_context.
|
||||
"""
|
||||
# Log in the user in the main thread
|
||||
with login_app.test_request_context():
|
||||
login_user(test_user)
|
||||
assert current_user.is_authenticated
|
||||
assert current_user.id == "test_user"
|
||||
|
||||
# Store the result of the thread execution
|
||||
result = {"user_accessible": True, "error": None}
|
||||
|
||||
# Define a function to run in a separate thread
|
||||
def check_user_in_thread():
|
||||
try:
|
||||
# Try to access current_user in a different thread with app_context
|
||||
with login_app.app_context():
|
||||
# This should fail because current_user is not accessible across threads
|
||||
# without preserve_flask_contexts
|
||||
result["user_accessible"] = current_user.is_authenticated
|
||||
except Exception as e:
|
||||
result["error"] = str(e)
|
||||
|
||||
# Run the function in a separate thread
|
||||
thread = threading.Thread(target=check_user_in_thread)
|
||||
thread.start()
|
||||
thread.join()
|
||||
|
||||
# Verify that we got an error or current_user is not authenticated
|
||||
assert result["error"] is not None or (result["user_accessible"] is not None and not result["user_accessible"])
|
||||
|
||||
|
||||
def test_current_user_accessible_with_preserve_flask_contexts(login_app: Flask, test_user: User):
|
||||
"""
|
||||
Test that current_user is accessible in a different thread with preserve_flask_contexts.
|
||||
|
||||
This test demonstrates that with the preserve_flask_contexts, we can access
|
||||
current_user in a different thread.
|
||||
"""
|
||||
# Log in the user in the main thread
|
||||
with login_app.test_request_context():
|
||||
login_user(test_user)
|
||||
assert current_user.is_authenticated
|
||||
assert current_user.id == "test_user"
|
||||
|
||||
# Save the context variables
|
||||
context_vars = contextvars.copy_context()
|
||||
|
||||
# Store the result of the thread execution
|
||||
result = {"user_accessible": False, "user_id": None, "error": None}
|
||||
|
||||
# Define a function to run in a separate thread
|
||||
def check_user_in_thread_with_manager():
|
||||
try:
|
||||
# Use preserve_flask_contexts to access current_user in a different thread
|
||||
with preserve_flask_contexts(login_app, context_vars):
|
||||
from flask_login import current_user
|
||||
|
||||
if current_user:
|
||||
result["user_accessible"] = True
|
||||
result["user_id"] = current_user.id
|
||||
else:
|
||||
result["user_accessible"] = False
|
||||
except Exception as e:
|
||||
result["error"] = str(e)
|
||||
|
||||
# Run the function in a separate thread
|
||||
thread = threading.Thread(target=check_user_in_thread_with_manager)
|
||||
thread.start()
|
||||
thread.join()
|
||||
|
||||
# Verify that current_user is accessible and has the correct ID
|
||||
assert result["error"] is None
|
||||
assert result["user_accessible"] is True
|
||||
assert result["user_id"] == "test_user"
|
||||
65
dify/api/tests/unit_tests/libs/test_helper.py
Normal file
65
dify/api/tests/unit_tests/libs/test_helper.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import pytest
|
||||
|
||||
from libs.helper import extract_tenant_id
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
class TestExtractTenantId:
|
||||
"""Test cases for the extract_tenant_id utility function."""
|
||||
|
||||
def test_extract_tenant_id_from_account_with_tenant(self):
|
||||
"""Test extracting tenant_id from Account with current_tenant_id."""
|
||||
# Create a mock Account object
|
||||
account = Account(name="test", email="test@example.com")
|
||||
# Mock the current_tenant_id property
|
||||
account._current_tenant = type("MockTenant", (), {"id": "account-tenant-123"})()
|
||||
|
||||
tenant_id = extract_tenant_id(account)
|
||||
assert tenant_id == "account-tenant-123"
|
||||
|
||||
def test_extract_tenant_id_from_account_without_tenant(self):
|
||||
"""Test extracting tenant_id from Account without current_tenant_id."""
|
||||
# Create a mock Account object
|
||||
account = Account(name="test", email="test@example.com")
|
||||
account._current_tenant = None
|
||||
|
||||
tenant_id = extract_tenant_id(account)
|
||||
assert tenant_id is None
|
||||
|
||||
def test_extract_tenant_id_from_enduser_with_tenant(self):
|
||||
"""Test extracting tenant_id from EndUser with tenant_id."""
|
||||
# Create a mock EndUser object
|
||||
end_user = EndUser()
|
||||
end_user.tenant_id = "enduser-tenant-456"
|
||||
|
||||
tenant_id = extract_tenant_id(end_user)
|
||||
assert tenant_id == "enduser-tenant-456"
|
||||
|
||||
def test_extract_tenant_id_from_enduser_without_tenant(self):
|
||||
"""Test extracting tenant_id from EndUser without tenant_id."""
|
||||
# Create a mock EndUser object
|
||||
end_user = EndUser()
|
||||
end_user.tenant_id = None
|
||||
|
||||
tenant_id = extract_tenant_id(end_user)
|
||||
assert tenant_id is None
|
||||
|
||||
def test_extract_tenant_id_with_invalid_user_type(self):
|
||||
"""Test extracting tenant_id with invalid user type raises ValueError."""
|
||||
invalid_user = "not_a_user_object"
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
|
||||
extract_tenant_id(invalid_user)
|
||||
|
||||
def test_extract_tenant_id_with_none_user(self):
|
||||
"""Test extracting tenant_id with None user raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
|
||||
extract_tenant_id(None)
|
||||
|
||||
def test_extract_tenant_id_with_dict_user(self):
|
||||
"""Test extracting tenant_id with dict user raises ValueError."""
|
||||
dict_user = {"id": "123", "tenant_id": "456"}
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
|
||||
extract_tenant_id(dict_user)
|
||||
109
dify/api/tests/unit_tests/libs/test_json_in_md_parser.py
Normal file
109
dify/api/tests/unit_tests/libs/test_json_in_md_parser.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import pytest
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from libs.json_in_md_parser import (
|
||||
parse_and_check_json_markdown,
|
||||
parse_json_markdown,
|
||||
)
|
||||
|
||||
|
||||
def test_parse_json_markdown_triple_backticks_json():
|
||||
src = """
|
||||
```json
|
||||
{"a": 1, "b": "x"}
|
||||
```
|
||||
"""
|
||||
assert parse_json_markdown(src) == {"a": 1, "b": "x"}
|
||||
|
||||
|
||||
def test_parse_json_markdown_triple_backticks_generic():
|
||||
src = """
|
||||
```
|
||||
{"k": [1, 2, 3]}
|
||||
```
|
||||
"""
|
||||
assert parse_json_markdown(src) == {"k": [1, 2, 3]}
|
||||
|
||||
|
||||
def test_parse_json_markdown_single_backticks():
|
||||
src = '`{"x": true}`'
|
||||
assert parse_json_markdown(src) == {"x": True}
|
||||
|
||||
|
||||
def test_parse_json_markdown_braces_only():
|
||||
src = ' {\n \t"ok": "yes"\n} '
|
||||
assert parse_json_markdown(src) == {"ok": "yes"}
|
||||
|
||||
|
||||
def test_parse_json_markdown_not_found():
|
||||
with pytest.raises(ValueError):
|
||||
parse_json_markdown("no json here")
|
||||
|
||||
|
||||
def test_parse_and_check_json_markdown_missing_key():
|
||||
src = """
|
||||
```
|
||||
{"present": 1}
|
||||
```
|
||||
"""
|
||||
with pytest.raises(OutputParserError) as exc:
|
||||
parse_and_check_json_markdown(src, ["present", "missing"])
|
||||
assert "expected key `missing`" in str(exc.value)
|
||||
|
||||
|
||||
def test_parse_and_check_json_markdown_invalid_json():
|
||||
src = """
|
||||
```json
|
||||
{invalid json}
|
||||
```
|
||||
"""
|
||||
with pytest.raises(OutputParserError) as exc:
|
||||
parse_and_check_json_markdown(src, [])
|
||||
assert "got invalid json object" in str(exc.value)
|
||||
|
||||
|
||||
def test_parse_and_check_json_markdown_success():
|
||||
src = """
|
||||
```json
|
||||
{"present": 1, "other": 2}
|
||||
```
|
||||
"""
|
||||
obj = parse_and_check_json_markdown(src, ["present"])
|
||||
assert obj == {"present": 1, "other": 2}
|
||||
|
||||
|
||||
def test_parse_and_check_json_markdown_multiple_blocks_fails():
|
||||
src = """
|
||||
```json
|
||||
{"a": 1}
|
||||
```
|
||||
Some text
|
||||
```json
|
||||
{"b": 2}
|
||||
```
|
||||
"""
|
||||
# The current implementation is greedy and will match from the first
|
||||
# opening fence to the last closing fence, causing JSON decode failure.
|
||||
with pytest.raises(OutputParserError):
|
||||
parse_and_check_json_markdown(src, [])
|
||||
|
||||
|
||||
def test_parse_and_check_json_markdown_handles_think_fenced_and_raw_variants():
|
||||
expected = {"keywords": ["2"], "category_id": "2", "category_name": "2"}
|
||||
cases = [
|
||||
"""
|
||||
```json
|
||||
[{"keywords": ["2"], "category_id": "2", "category_name": "2"}]
|
||||
```, error: Expecting value: line 1 column 1 (char 0)
|
||||
""",
|
||||
"""
|
||||
```json
|
||||
{"keywords": ["2"], "category_id": "2", "category_name": "2"}
|
||||
```, error: Extra data: line 2 column 5 (char 66)
|
||||
""",
|
||||
'{"keywords": ["2"], "category_id": "2", "category_name": "2"}',
|
||||
'[{"keywords": ["2"], "category_id": "2", "category_name": "2"}]',
|
||||
]
|
||||
for src in cases:
|
||||
obj = parse_and_check_json_markdown(src, ["keywords", "category_id", "category_name"])
|
||||
assert obj == expected
|
||||
63
dify/api/tests/unit_tests/libs/test_jwt_imports.py
Normal file
63
dify/api/tests/unit_tests/libs/test_jwt_imports.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Test PyJWT import paths to catch changes in library structure."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestPyJWTImports:
|
||||
"""Test PyJWT import paths used throughout the codebase."""
|
||||
|
||||
def test_invalid_token_error_import(self):
|
||||
"""Test that InvalidTokenError can be imported as used in login controller."""
|
||||
# This test verifies the import path used in controllers/web/login.py:2
|
||||
# If PyJWT changes this import path, this test will fail early
|
||||
try:
|
||||
from jwt import InvalidTokenError
|
||||
|
||||
# Verify it's the correct exception class
|
||||
assert issubclass(InvalidTokenError, Exception)
|
||||
|
||||
# Test that it can be instantiated
|
||||
error = InvalidTokenError("test error")
|
||||
assert str(error) == "test error"
|
||||
|
||||
except ImportError as e:
|
||||
pytest.fail(f"Failed to import InvalidTokenError from jwt: {e}")
|
||||
|
||||
def test_jwt_exceptions_import(self):
|
||||
"""Test that jwt.exceptions imports work as expected."""
|
||||
# Alternative import path that might be used
|
||||
try:
|
||||
# Verify it's the same class as the direct import
|
||||
from jwt import InvalidTokenError
|
||||
from jwt.exceptions import InvalidTokenError as InvalidTokenErrorAlt
|
||||
|
||||
assert InvalidTokenError is InvalidTokenErrorAlt
|
||||
|
||||
except ImportError as e:
|
||||
pytest.fail(f"Failed to import InvalidTokenError from jwt.exceptions: {e}")
|
||||
|
||||
def test_other_jwt_exceptions_available(self):
|
||||
"""Test that other common JWT exceptions are available."""
|
||||
# Test other exceptions that might be used in the codebase
|
||||
try:
|
||||
from jwt import DecodeError, ExpiredSignatureError, InvalidSignatureError
|
||||
|
||||
# Verify they are exception classes
|
||||
assert issubclass(DecodeError, Exception)
|
||||
assert issubclass(ExpiredSignatureError, Exception)
|
||||
assert issubclass(InvalidSignatureError, Exception)
|
||||
|
||||
except ImportError as e:
|
||||
pytest.fail(f"Failed to import JWT exceptions: {e}")
|
||||
|
||||
def test_jwt_main_functions_available(self):
|
||||
"""Test that main JWT functions are available."""
|
||||
try:
|
||||
from jwt import decode, encode
|
||||
|
||||
# Verify they are callable
|
||||
assert callable(decode)
|
||||
assert callable(encode)
|
||||
|
||||
except ImportError as e:
|
||||
pytest.fail(f"Failed to import JWT main functions: {e}")
|
||||
243
dify/api/tests/unit_tests/libs/test_login.py
Normal file
243
dify/api/tests/unit_tests/libs/test_login.py
Normal file
@@ -0,0 +1,243 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
from flask_login import LoginManager, UserMixin
|
||||
|
||||
from libs.login import _get_user, current_user, login_required
|
||||
|
||||
|
||||
class MockUser(UserMixin):
|
||||
"""Mock user class for testing."""
|
||||
|
||||
def __init__(self, id: str, is_authenticated: bool = True):
|
||||
self.id = id
|
||||
self._is_authenticated = is_authenticated
|
||||
|
||||
@property
|
||||
def is_authenticated(self):
|
||||
return self._is_authenticated
|
||||
|
||||
|
||||
def mock_csrf_check(*args, **kwargs):
|
||||
return
|
||||
|
||||
|
||||
class TestLoginRequired:
|
||||
"""Test cases for login_required decorator."""
|
||||
|
||||
@pytest.fixture
|
||||
@patch("libs.login.check_csrf_token", mock_csrf_check)
|
||||
def setup_app(self, app: Flask):
|
||||
"""Set up Flask app with login manager."""
|
||||
# Initialize login manager
|
||||
login_manager = LoginManager()
|
||||
login_manager.init_app(app)
|
||||
|
||||
# Mock unauthorized handler
|
||||
login_manager.unauthorized = MagicMock(return_value="Unauthorized")
|
||||
|
||||
# Add a dummy user loader to prevent exceptions
|
||||
@login_manager.user_loader
|
||||
def load_user(user_id):
|
||||
return None
|
||||
|
||||
return app
|
||||
|
||||
@patch("libs.login.check_csrf_token", mock_csrf_check)
|
||||
def test_authenticated_user_can_access_protected_view(self, setup_app: Flask):
|
||||
"""Test that authenticated users can access protected views."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
with setup_app.test_request_context():
|
||||
# Mock authenticated user
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
|
||||
@patch("libs.login.check_csrf_token", mock_csrf_check)
|
||||
def test_unauthenticated_user_cannot_access_protected_view(self, setup_app: Flask):
|
||||
"""Test that unauthenticated users are redirected."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
with setup_app.test_request_context():
|
||||
# Mock unauthenticated user
|
||||
mock_user = MockUser("test_user", is_authenticated=False)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
result = protected_view()
|
||||
assert result == "Unauthorized"
|
||||
setup_app.login_manager.unauthorized.assert_called_once()
|
||||
|
||||
@patch("libs.login.check_csrf_token", mock_csrf_check)
|
||||
def test_login_disabled_allows_unauthenticated_access(self, setup_app: Flask):
|
||||
"""Test that LOGIN_DISABLED config bypasses authentication."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
with setup_app.test_request_context():
|
||||
# Mock unauthenticated user and LOGIN_DISABLED
|
||||
mock_user = MockUser("test_user", is_authenticated=False)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
with patch("libs.login.dify_config") as mock_config:
|
||||
mock_config.LOGIN_DISABLED = True
|
||||
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
# Ensure unauthorized was not called
|
||||
setup_app.login_manager.unauthorized.assert_not_called()
|
||||
|
||||
@patch("libs.login.check_csrf_token", mock_csrf_check)
|
||||
def test_options_request_bypasses_authentication(self, setup_app: Flask):
|
||||
"""Test that OPTIONS requests are exempt from authentication."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
with setup_app.test_request_context(method="OPTIONS"):
|
||||
# Mock unauthenticated user
|
||||
mock_user = MockUser("test_user", is_authenticated=False)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
# Ensure unauthorized was not called
|
||||
setup_app.login_manager.unauthorized.assert_not_called()
|
||||
|
||||
@patch("libs.login.check_csrf_token", mock_csrf_check)
|
||||
def test_flask_2_compatibility(self, setup_app: Flask):
|
||||
"""Test Flask 2.x compatibility with ensure_sync."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
# Mock Flask 2.x ensure_sync
|
||||
setup_app.ensure_sync = MagicMock(return_value=lambda: "Synced content")
|
||||
|
||||
with setup_app.test_request_context():
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
result = protected_view()
|
||||
assert result == "Synced content"
|
||||
setup_app.ensure_sync.assert_called_once()
|
||||
|
||||
@patch("libs.login.check_csrf_token", mock_csrf_check)
|
||||
def test_flask_1_compatibility(self, setup_app: Flask):
|
||||
"""Test Flask 1.x compatibility without ensure_sync."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
# Remove ensure_sync to simulate Flask 1.x
|
||||
if hasattr(setup_app, "ensure_sync"):
|
||||
delattr(setup_app, "ensure_sync")
|
||||
|
||||
with setup_app.test_request_context():
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
|
||||
|
||||
class TestGetUser:
|
||||
"""Test cases for _get_user function."""
|
||||
|
||||
def test_get_user_returns_user_from_g(self, app: Flask):
|
||||
"""Test that _get_user returns user from g._login_user."""
|
||||
mock_user = MockUser("test_user")
|
||||
|
||||
with app.test_request_context():
|
||||
g._login_user = mock_user
|
||||
user = _get_user()
|
||||
assert user == mock_user
|
||||
assert user.id == "test_user"
|
||||
|
||||
def test_get_user_loads_user_if_not_in_g(self, app: Flask):
|
||||
"""Test that _get_user loads user if not already in g."""
|
||||
mock_user = MockUser("test_user")
|
||||
|
||||
# Mock login manager
|
||||
login_manager = MagicMock()
|
||||
login_manager._load_user = MagicMock()
|
||||
app.login_manager = login_manager
|
||||
|
||||
with app.test_request_context():
|
||||
# Simulate _load_user setting g._login_user
|
||||
def side_effect():
|
||||
g._login_user = mock_user
|
||||
|
||||
login_manager._load_user.side_effect = side_effect
|
||||
|
||||
user = _get_user()
|
||||
assert user == mock_user
|
||||
login_manager._load_user.assert_called_once()
|
||||
|
||||
def test_get_user_returns_none_without_request_context(self, app: Flask):
|
||||
"""Test that _get_user returns None outside request context."""
|
||||
# Outside of request context
|
||||
user = _get_user()
|
||||
assert user is None
|
||||
|
||||
|
||||
class TestCurrentUser:
|
||||
"""Test cases for current_user proxy."""
|
||||
|
||||
def test_current_user_proxy_returns_authenticated_user(self, app: Flask):
|
||||
"""Test that current_user proxy returns authenticated user."""
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
|
||||
with app.test_request_context():
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
assert current_user.id == "test_user"
|
||||
assert current_user.is_authenticated is True
|
||||
|
||||
def test_current_user_proxy_returns_none_when_no_user(self, app: Flask):
|
||||
"""Test that current_user proxy handles None user."""
|
||||
with app.test_request_context():
|
||||
with patch("libs.login._get_user", return_value=None):
|
||||
# When _get_user returns None, accessing attributes should fail
|
||||
# or current_user should evaluate to falsy
|
||||
try:
|
||||
# Try to access an attribute that would exist on a real user
|
||||
_ = current_user.id
|
||||
pytest.fail("Should have raised AttributeError")
|
||||
except AttributeError:
|
||||
# This is expected when current_user is None
|
||||
pass
|
||||
|
||||
def test_current_user_proxy_thread_safety(self, app: Flask):
|
||||
"""Test that current_user proxy is thread-safe."""
|
||||
import threading
|
||||
|
||||
results = {}
|
||||
|
||||
def check_user_in_thread(user_id: str, index: int):
|
||||
with app.test_request_context():
|
||||
mock_user = MockUser(user_id)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
results[index] = current_user.id
|
||||
|
||||
# Create multiple threads with different users
|
||||
threads = []
|
||||
for i in range(5):
|
||||
thread = threading.Thread(target=check_user_in_thread, args=(f"user_{i}", i))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# Wait for all threads to complete
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# Verify each thread got its own user
|
||||
for i in range(5):
|
||||
assert results[i] == f"user_{i}"
|
||||
19
dify/api/tests/unit_tests/libs/test_oauth_base.py
Normal file
19
dify/api/tests/unit_tests/libs/test_oauth_base.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import pytest
|
||||
|
||||
from libs.oauth import OAuth
|
||||
|
||||
|
||||
def test_oauth_base_methods_raise_not_implemented():
|
||||
oauth = OAuth(client_id="id", client_secret="sec", redirect_uri="uri")
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
oauth.get_authorization_url()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
oauth.get_access_token("code")
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
oauth.get_raw_user_info("token")
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
oauth._transform_user_info({})
|
||||
249
dify/api/tests/unit_tests/libs/test_oauth_clients.py
Normal file
249
dify/api/tests/unit_tests/libs/test_oauth_clients.py
Normal file
@@ -0,0 +1,249 @@
|
||||
import urllib.parse
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||
|
||||
|
||||
class BaseOAuthTest:
|
||||
"""Base class for OAuth provider tests with common fixtures"""
|
||||
|
||||
@pytest.fixture
|
||||
def oauth_config(self):
|
||||
return {
|
||||
"client_id": "test_client_id",
|
||||
"client_secret": "test_client_secret",
|
||||
"redirect_uri": "http://localhost/callback",
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response(self):
|
||||
response = MagicMock()
|
||||
response.json.return_value = {}
|
||||
return response
|
||||
|
||||
def parse_auth_url(self, url):
|
||||
"""Helper to parse authorization URL"""
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
params = urllib.parse.parse_qs(parsed.query)
|
||||
return parsed, params
|
||||
|
||||
|
||||
class TestGitHubOAuth(BaseOAuthTest):
|
||||
@pytest.fixture
|
||||
def oauth(self, oauth_config):
|
||||
return GitHubOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("invite_token", "expected_state"),
|
||||
[
|
||||
(None, None),
|
||||
("test_invite_token", "test_invite_token"),
|
||||
("", None),
|
||||
],
|
||||
)
|
||||
def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
|
||||
url = oauth.get_authorization_url(invite_token)
|
||||
parsed, params = self.parse_auth_url(url)
|
||||
|
||||
assert parsed.scheme == "https"
|
||||
assert parsed.netloc == "github.com"
|
||||
assert parsed.path == "/login/oauth/authorize"
|
||||
assert params["client_id"][0] == oauth_config["client_id"]
|
||||
assert params["redirect_uri"][0] == oauth_config["redirect_uri"]
|
||||
assert params["scope"][0] == "user:email"
|
||||
|
||||
if expected_state:
|
||||
assert params["state"][0] == expected_state
|
||||
else:
|
||||
assert "state" not in params
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("response_data", "expected_token", "should_raise"),
|
||||
[
|
||||
({"access_token": "test_token"}, "test_token", False),
|
||||
({"error": "invalid_grant"}, None, True),
|
||||
({}, None, True),
|
||||
],
|
||||
)
|
||||
@patch("httpx.post")
|
||||
def test_should_retrieve_access_token(
|
||||
self, mock_post, oauth, mock_response, response_data, expected_token, should_raise
|
||||
):
|
||||
mock_response.json.return_value = response_data
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
if should_raise:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
oauth.get_access_token("test_code")
|
||||
assert "Error in GitHub OAuth" in str(exc_info.value)
|
||||
else:
|
||||
token = oauth.get_access_token("test_code")
|
||||
assert token == expected_token
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("user_data", "email_data", "expected_email"),
|
||||
[
|
||||
# User with primary email
|
||||
(
|
||||
{"id": 12345, "login": "testuser", "name": "Test User"},
|
||||
[
|
||||
{"email": "secondary@example.com", "primary": False},
|
||||
{"email": "primary@example.com", "primary": True},
|
||||
],
|
||||
"primary@example.com",
|
||||
),
|
||||
# User with no emails - fallback to noreply
|
||||
({"id": 12345, "login": "testuser", "name": "Test User"}, [], "12345+testuser@users.noreply.github.com"),
|
||||
# User with only secondary email - fallback to noreply
|
||||
(
|
||||
{"id": 12345, "login": "testuser", "name": "Test User"},
|
||||
[{"email": "secondary@example.com", "primary": False}],
|
||||
"12345+testuser@users.noreply.github.com",
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch("httpx.get")
|
||||
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email):
|
||||
user_response = MagicMock()
|
||||
user_response.json.return_value = user_data
|
||||
|
||||
email_response = MagicMock()
|
||||
email_response.json.return_value = email_data
|
||||
|
||||
mock_get.side_effect = [user_response, email_response]
|
||||
|
||||
user_info = oauth.get_user_info("test_token")
|
||||
|
||||
assert user_info.id == str(user_data["id"])
|
||||
assert user_info.name == user_data["name"]
|
||||
assert user_info.email == expected_email
|
||||
|
||||
@patch("httpx.get")
|
||||
def test_should_handle_network_errors(self, mock_get, oauth):
|
||||
mock_get.side_effect = httpx.RequestError("Network error")
|
||||
|
||||
with pytest.raises(httpx.RequestError):
|
||||
oauth.get_raw_user_info("test_token")
|
||||
|
||||
|
||||
class TestGoogleOAuth(BaseOAuthTest):
|
||||
@pytest.fixture
|
||||
def oauth(self, oauth_config):
|
||||
return GoogleOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("invite_token", "expected_state"),
|
||||
[
|
||||
(None, None),
|
||||
("test_invite_token", "test_invite_token"),
|
||||
("", None),
|
||||
],
|
||||
)
|
||||
def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
|
||||
url = oauth.get_authorization_url(invite_token)
|
||||
parsed, params = self.parse_auth_url(url)
|
||||
|
||||
assert parsed.scheme == "https"
|
||||
assert parsed.netloc == "accounts.google.com"
|
||||
assert parsed.path == "/o/oauth2/v2/auth"
|
||||
assert params["client_id"][0] == oauth_config["client_id"]
|
||||
assert params["redirect_uri"][0] == oauth_config["redirect_uri"]
|
||||
assert params["response_type"][0] == "code"
|
||||
assert params["scope"][0] == "openid email"
|
||||
|
||||
if expected_state:
|
||||
assert params["state"][0] == expected_state
|
||||
else:
|
||||
assert "state" not in params
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("response_data", "expected_token", "should_raise"),
|
||||
[
|
||||
({"access_token": "test_token"}, "test_token", False),
|
||||
({"error": "invalid_grant"}, None, True),
|
||||
({}, None, True),
|
||||
],
|
||||
)
|
||||
@patch("httpx.post")
|
||||
def test_should_retrieve_access_token(
|
||||
self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise
|
||||
):
|
||||
mock_response.json.return_value = response_data
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
if should_raise:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
oauth.get_access_token("test_code")
|
||||
assert "Error in Google OAuth" in str(exc_info.value)
|
||||
else:
|
||||
token = oauth.get_access_token("test_code")
|
||||
assert token == expected_token
|
||||
|
||||
mock_post.assert_called_once_with(
|
||||
oauth._TOKEN_URL,
|
||||
data={
|
||||
"client_id": oauth_config["client_id"],
|
||||
"client_secret": oauth_config["client_secret"],
|
||||
"code": "test_code",
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": oauth_config["redirect_uri"],
|
||||
},
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("user_data", "expected_name"),
|
||||
[
|
||||
({"sub": "123", "email": "test@example.com", "email_verified": True}, ""),
|
||||
({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string
|
||||
],
|
||||
)
|
||||
@patch("httpx.get")
|
||||
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name):
|
||||
mock_response.json.return_value = user_data
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
user_info = oauth.get_user_info("test_token")
|
||||
|
||||
assert user_info.id == user_data["sub"]
|
||||
assert user_info.name == expected_name
|
||||
assert user_info.email == user_data["email"]
|
||||
|
||||
mock_get.assert_called_once_with(oauth._USER_INFO_URL, headers={"Authorization": "Bearer test_token"})
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"exception_type",
|
||||
[
|
||||
httpx.HTTPError,
|
||||
httpx.ConnectError,
|
||||
httpx.TimeoutException,
|
||||
],
|
||||
)
|
||||
@patch("httpx.get")
|
||||
def test_should_handle_http_errors(self, mock_get, oauth, exception_type):
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status.side_effect = exception_type("Error")
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with pytest.raises(exception_type):
|
||||
oauth.get_raw_user_info("invalid_token")
|
||||
|
||||
|
||||
class TestOAuthUserInfo:
|
||||
@pytest.mark.parametrize(
|
||||
"user_data",
|
||||
[
|
||||
{"id": "123", "name": "Test User", "email": "test@example.com"},
|
||||
{"id": "456", "name": "", "email": "user@domain.com"},
|
||||
{"id": "789", "name": "Another User", "email": "another@test.org"},
|
||||
],
|
||||
)
|
||||
def test_should_create_user_info_dataclass(self, user_data):
|
||||
user_info = OAuthUserInfo(**user_data)
|
||||
|
||||
assert user_info.id == user_data["id"]
|
||||
assert user_info.name == user_data["name"]
|
||||
assert user_info.email == user_data["email"]
|
||||
25
dify/api/tests/unit_tests/libs/test_orjson.py
Normal file
25
dify/api/tests/unit_tests/libs/test_orjson.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import orjson
|
||||
import pytest
|
||||
|
||||
from libs.orjson import orjson_dumps
|
||||
|
||||
|
||||
def test_orjson_dumps_round_trip_basic():
|
||||
obj = {"a": 1, "b": [1, 2, 3], "c": {"d": True}}
|
||||
s = orjson_dumps(obj)
|
||||
assert orjson.loads(s) == obj
|
||||
|
||||
|
||||
def test_orjson_dumps_with_unicode_and_indent():
|
||||
obj = {"msg": "你好,Dify"}
|
||||
s = orjson_dumps(obj, option=orjson.OPT_INDENT_2)
|
||||
# contains indentation newline/spaces
|
||||
assert "\n" in s
|
||||
assert orjson.loads(s) == obj
|
||||
|
||||
|
||||
def test_orjson_dumps_non_utf8_encoding_fails():
|
||||
obj = {"msg": "你好"}
|
||||
# orjson.dumps() always produces UTF-8 bytes; decoding with non-UTF8 fails.
|
||||
with pytest.raises(UnicodeDecodeError):
|
||||
orjson_dumps(obj, encoding="ascii")
|
||||
58
dify/api/tests/unit_tests/libs/test_pandas.py
Normal file
58
dify/api/tests/unit_tests/libs/test_pandas.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def test_pandas_csv(tmp_path, monkeypatch):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]}
|
||||
df1 = pd.DataFrame(data)
|
||||
|
||||
# write to csv file
|
||||
csv_file_path = tmp_path.joinpath("example.csv")
|
||||
df1.to_csv(csv_file_path, index=False)
|
||||
|
||||
# read from csv file
|
||||
df2 = pd.read_csv(csv_file_path, on_bad_lines="skip")
|
||||
assert df2[df2.columns[0]].to_list() == data["col1"]
|
||||
assert df2[df2.columns[1]].to_list() == data["col2"]
|
||||
|
||||
|
||||
def test_pandas_xlsx(tmp_path, monkeypatch):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]}
|
||||
df1 = pd.DataFrame(data)
|
||||
|
||||
# write to xlsx file
|
||||
xlsx_file_path = tmp_path.joinpath("example.xlsx")
|
||||
df1.to_excel(xlsx_file_path, index=False)
|
||||
|
||||
# read from xlsx file
|
||||
df2 = pd.read_excel(xlsx_file_path)
|
||||
assert df2[df2.columns[0]].to_list() == data["col1"]
|
||||
assert df2[df2.columns[1]].to_list() == data["col2"]
|
||||
|
||||
|
||||
def test_pandas_xlsx_with_sheets(tmp_path, monkeypatch):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
data1 = {"col1": [1, 2, 3, 4, 5], "col2": ["A", "B", "C", "D", "E"]}
|
||||
df1 = pd.DataFrame(data1)
|
||||
|
||||
data2 = {"col1": [6, 7, 8, 9, 10], "col2": ["F", "G", "H", "I", "J"]}
|
||||
df2 = pd.DataFrame(data2)
|
||||
|
||||
# write to xlsx file with sheets
|
||||
xlsx_file_path = tmp_path.joinpath("example_with_sheets.xlsx")
|
||||
sheet1 = "Sheet1"
|
||||
sheet2 = "Sheet2"
|
||||
with pd.ExcelWriter(xlsx_file_path) as excel_writer:
|
||||
df1.to_excel(excel_writer, sheet_name=sheet1, index=False)
|
||||
df2.to_excel(excel_writer, sheet_name=sheet2, index=False)
|
||||
|
||||
# read from xlsx file with sheets
|
||||
with pd.ExcelFile(xlsx_file_path) as excel_file:
|
||||
df1 = pd.read_excel(excel_file, sheet_name=sheet1)
|
||||
assert df1[df1.columns[0]].to_list() == data1["col1"]
|
||||
assert df1[df1.columns[1]].to_list() == data1["col2"]
|
||||
|
||||
df2 = pd.read_excel(excel_file, sheet_name=sheet2)
|
||||
assert df2[df2.columns[0]].to_list() == data2["col1"]
|
||||
assert df2[df2.columns[1]].to_list() == data2["col2"]
|
||||
205
dify/api/tests/unit_tests/libs/test_passport.py
Normal file
205
dify/api/tests/unit_tests/libs/test_passport.py
Normal file
@@ -0,0 +1,205 @@
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from libs.passport import PassportService
|
||||
|
||||
|
||||
class TestPassportService:
|
||||
"""Test PassportService JWT operations"""
|
||||
|
||||
@pytest.fixture
|
||||
def passport_service(self):
|
||||
"""Create PassportService instance with test secret key"""
|
||||
with patch("libs.passport.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "test-secret-key-for-testing"
|
||||
return PassportService()
|
||||
|
||||
@pytest.fixture
|
||||
def another_passport_service(self):
|
||||
"""Create another PassportService instance with different secret key"""
|
||||
with patch("libs.passport.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "another-secret-key-for-testing"
|
||||
return PassportService()
|
||||
|
||||
# Core functionality tests
|
||||
def test_should_issue_and_verify_token(self, passport_service):
|
||||
"""Test complete JWT lifecycle: issue and verify"""
|
||||
payload = {"user_id": "123", "app_code": "test-app"}
|
||||
token = passport_service.issue(payload)
|
||||
|
||||
# Verify token format
|
||||
assert isinstance(token, str)
|
||||
assert len(token.split(".")) == 3 # JWT format: header.payload.signature
|
||||
|
||||
# Verify token content
|
||||
decoded = passport_service.verify(token)
|
||||
assert decoded == payload
|
||||
|
||||
def test_should_handle_different_payload_types(self, passport_service):
|
||||
"""Test issuing and verifying tokens with different payload types"""
|
||||
test_cases = [
|
||||
{"string": "value"},
|
||||
{"number": 42},
|
||||
{"float": 3.14},
|
||||
{"boolean": True},
|
||||
{"null": None},
|
||||
{"array": [1, 2, 3]},
|
||||
{"nested": {"key": "value"}},
|
||||
{"unicode": "中文测试"},
|
||||
{"emoji": "🔐"},
|
||||
{}, # Empty payload
|
||||
]
|
||||
|
||||
for payload in test_cases:
|
||||
token = passport_service.issue(payload)
|
||||
decoded = passport_service.verify(token)
|
||||
assert decoded == payload
|
||||
|
||||
# Security tests
|
||||
def test_should_reject_modified_token(self, passport_service):
|
||||
"""Test that any modification to token invalidates it"""
|
||||
token = passport_service.issue({"user": "test"})
|
||||
|
||||
# Test multiple modification points
|
||||
test_positions = [0, len(token) // 3, len(token) // 2, len(token) - 1]
|
||||
|
||||
for pos in test_positions:
|
||||
if pos < len(token) and token[pos] != ".":
|
||||
# Change one character
|
||||
tampered = token[:pos] + ("X" if token[pos] != "X" else "Y") + token[pos + 1 :]
|
||||
with pytest.raises(Unauthorized):
|
||||
passport_service.verify(tampered)
|
||||
|
||||
def test_should_reject_token_with_different_secret_key(self, passport_service, another_passport_service):
|
||||
"""Test key isolation - token from one service should not work with another"""
|
||||
payload = {"user_id": "123", "app_code": "test-app"}
|
||||
token = passport_service.issue(payload)
|
||||
|
||||
with pytest.raises(Unauthorized) as exc_info:
|
||||
another_passport_service.verify(token)
|
||||
assert str(exc_info.value) == "401 Unauthorized: Invalid token signature."
|
||||
|
||||
def test_should_use_hs256_algorithm(self, passport_service):
|
||||
"""Test that HS256 algorithm is used for signing"""
|
||||
payload = {"test": "data"}
|
||||
token = passport_service.issue(payload)
|
||||
|
||||
# Decode header without relying on JWT internals
|
||||
# Use jwt.get_unverified_header which is a public API
|
||||
header = jwt.get_unverified_header(token)
|
||||
assert header["alg"] == "HS256"
|
||||
|
||||
def test_should_reject_token_with_wrong_algorithm(self, passport_service):
|
||||
"""Test rejection of token signed with different algorithm"""
|
||||
payload = {"user_id": "123"}
|
||||
|
||||
# Create token with different algorithm
|
||||
with patch("libs.passport.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "test-secret-key-for-testing"
|
||||
# Create token with HS512 instead of HS256
|
||||
wrong_alg_token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS512")
|
||||
|
||||
# Should fail because service expects HS256
|
||||
# InvalidAlgorithmError is now caught by PyJWTError handler
|
||||
with pytest.raises(Unauthorized) as exc_info:
|
||||
passport_service.verify(wrong_alg_token)
|
||||
assert str(exc_info.value) == "401 Unauthorized: Invalid token."
|
||||
|
||||
# Exception handling tests
|
||||
def test_should_handle_invalid_tokens(self, passport_service):
|
||||
"""Test handling of various invalid token formats"""
|
||||
invalid_tokens = [
|
||||
("not.a.token", "Invalid token."),
|
||||
("invalid-jwt-format", "Invalid token."),
|
||||
("xxx.yyy.zzz", "Invalid token."),
|
||||
("a.b", "Invalid token."), # Missing signature
|
||||
("", "Invalid token."), # Empty string
|
||||
(" ", "Invalid token."), # Whitespace
|
||||
(None, "Invalid token."), # None value
|
||||
# Malformed base64
|
||||
("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.INVALID_BASE64!@#$.signature", "Invalid token."),
|
||||
]
|
||||
|
||||
for invalid_token, expected_message in invalid_tokens:
|
||||
with pytest.raises(Unauthorized) as exc_info:
|
||||
passport_service.verify(invalid_token)
|
||||
assert expected_message in str(exc_info.value)
|
||||
|
||||
def test_should_reject_expired_token(self, passport_service):
|
||||
"""Test rejection of expired token"""
|
||||
past_time = datetime.now(UTC) - timedelta(hours=1)
|
||||
payload = {"user_id": "123", "exp": past_time.timestamp()}
|
||||
|
||||
with patch("libs.passport.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "test-secret-key-for-testing"
|
||||
token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS256")
|
||||
|
||||
with pytest.raises(Unauthorized) as exc_info:
|
||||
passport_service.verify(token)
|
||||
assert str(exc_info.value) == "401 Unauthorized: Token has expired."
|
||||
|
||||
# Configuration tests
|
||||
def test_should_handle_empty_secret_key(self):
|
||||
"""Test behavior when SECRET_KEY is empty"""
|
||||
with patch("libs.passport.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = ""
|
||||
service = PassportService()
|
||||
|
||||
# Empty secret key should still work but is insecure
|
||||
payload = {"test": "data"}
|
||||
token = service.issue(payload)
|
||||
decoded = service.verify(token)
|
||||
assert decoded == payload
|
||||
|
||||
def test_should_handle_none_secret_key(self):
|
||||
"""Test behavior when SECRET_KEY is None"""
|
||||
with patch("libs.passport.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = None
|
||||
service = PassportService()
|
||||
|
||||
payload = {"test": "data"}
|
||||
# JWT library will raise TypeError when secret is None
|
||||
with pytest.raises((TypeError, jwt.exceptions.InvalidKeyError)):
|
||||
service.issue(payload)
|
||||
|
||||
# Boundary condition tests
|
||||
def test_should_handle_large_payload(self, passport_service):
|
||||
"""Test handling of large payload"""
|
||||
# Test with 100KB instead of 1MB for faster tests
|
||||
large_data = "x" * (100 * 1024)
|
||||
payload = {"data": large_data}
|
||||
|
||||
token = passport_service.issue(payload)
|
||||
decoded = passport_service.verify(token)
|
||||
|
||||
assert decoded["data"] == large_data
|
||||
|
||||
def test_should_handle_special_characters_in_payload(self, passport_service):
|
||||
"""Test handling of special characters in payload"""
|
||||
special_payloads = [
|
||||
{"special": "!@#$%^&*()"},
|
||||
{"quotes": 'He said "Hello"'},
|
||||
{"backslash": "path\\to\\file"},
|
||||
{"newline": "line1\nline2"},
|
||||
{"unicode": "🔐🔑🛡️"},
|
||||
{"mixed": "Test123!@#中文🔐"},
|
||||
]
|
||||
|
||||
for payload in special_payloads:
|
||||
token = passport_service.issue(payload)
|
||||
decoded = passport_service.verify(token)
|
||||
assert decoded == payload
|
||||
|
||||
def test_should_catch_generic_pyjwt_errors(self, passport_service):
|
||||
"""Test that generic PyJWTError exceptions are caught and converted to Unauthorized"""
|
||||
# Mock jwt.decode to raise a generic PyJWTError
|
||||
with patch("libs.passport.jwt.decode") as mock_decode:
|
||||
mock_decode.side_effect = jwt.exceptions.PyJWTError("Generic JWT error")
|
||||
|
||||
with pytest.raises(Unauthorized) as exc_info:
|
||||
passport_service.verify("some-token")
|
||||
assert str(exc_info.value) == "401 Unauthorized: Invalid token."
|
||||
74
dify/api/tests/unit_tests/libs/test_password.py
Normal file
74
dify/api/tests/unit_tests/libs/test_password.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import base64
|
||||
import binascii
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.password import compare_password, hash_password, valid_password
|
||||
|
||||
|
||||
class TestValidPassword:
|
||||
"""Test password format validation"""
|
||||
|
||||
def test_should_accept_valid_passwords(self):
|
||||
"""Test accepting valid password formats"""
|
||||
assert valid_password("password123") == "password123"
|
||||
assert valid_password("test1234") == "test1234"
|
||||
assert valid_password("Test123456") == "Test123456"
|
||||
|
||||
def test_should_reject_invalid_passwords(self):
|
||||
"""Test rejecting invalid password formats"""
|
||||
# Too short
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
valid_password("abc123")
|
||||
assert "Password must contain letters and numbers" in str(exc_info.value)
|
||||
|
||||
# No numbers
|
||||
with pytest.raises(ValueError):
|
||||
valid_password("abcdefgh")
|
||||
|
||||
# No letters
|
||||
with pytest.raises(ValueError):
|
||||
valid_password("12345678")
|
||||
|
||||
# Empty
|
||||
with pytest.raises(ValueError):
|
||||
valid_password("")
|
||||
|
||||
|
||||
class TestPasswordHashing:
|
||||
"""Test password hashing and comparison"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Setup test data"""
|
||||
self.password = "test123password"
|
||||
self.salt = os.urandom(16)
|
||||
self.salt_base64 = base64.b64encode(self.salt).decode()
|
||||
|
||||
password_hash = hash_password(self.password, self.salt)
|
||||
self.password_hash_base64 = base64.b64encode(password_hash).decode()
|
||||
|
||||
def test_should_verify_correct_password(self):
|
||||
"""Test correct password verification"""
|
||||
result = compare_password(self.password, self.password_hash_base64, self.salt_base64)
|
||||
assert result is True
|
||||
|
||||
def test_should_reject_wrong_password(self):
|
||||
"""Test rejection of incorrect passwords"""
|
||||
result = compare_password("wrongpassword", self.password_hash_base64, self.salt_base64)
|
||||
assert result is False
|
||||
|
||||
def test_should_handle_invalid_base64(self):
|
||||
"""Test handling of invalid base64 data"""
|
||||
# Invalid base64 hash
|
||||
with pytest.raises(binascii.Error):
|
||||
compare_password(self.password, "invalid_base64!", self.salt_base64)
|
||||
|
||||
# Invalid base64 salt
|
||||
with pytest.raises(binascii.Error):
|
||||
compare_password(self.password, self.password_hash_base64, "invalid_base64!")
|
||||
|
||||
def test_should_be_case_sensitive(self):
|
||||
"""Test password case sensitivity"""
|
||||
result = compare_password(self.password.upper(), self.password_hash_base64, self.salt_base64)
|
||||
assert result is False
|
||||
29
dify/api/tests/unit_tests/libs/test_rsa.py
Normal file
29
dify/api/tests/unit_tests/libs/test_rsa.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import rsa as pyrsa
|
||||
from Crypto.PublicKey import RSA
|
||||
|
||||
from libs import gmpy2_pkcs10aep_cipher
|
||||
|
||||
|
||||
def test_gmpy2_pkcs10aep_cipher():
|
||||
rsa_key_pair = pyrsa.newkeys(2048)
|
||||
public_key = rsa_key_pair[0].save_pkcs1()
|
||||
private_key = rsa_key_pair[1].save_pkcs1()
|
||||
|
||||
public_rsa_key = RSA.import_key(public_key)
|
||||
public_cipher_rsa2 = gmpy2_pkcs10aep_cipher.new(public_rsa_key)
|
||||
|
||||
private_rsa_key = RSA.import_key(private_key)
|
||||
private_cipher_rsa = gmpy2_pkcs10aep_cipher.new(private_rsa_key)
|
||||
|
||||
raw_text = "raw_text"
|
||||
raw_text_bytes = raw_text.encode()
|
||||
|
||||
# RSA encryption by public key and decryption by private key
|
||||
encrypted_by_pub_key = public_cipher_rsa2.encrypt(message=raw_text_bytes)
|
||||
decrypted_by_pub_key = private_cipher_rsa.decrypt(encrypted_by_pub_key)
|
||||
assert decrypted_by_pub_key == raw_text_bytes
|
||||
|
||||
# RSA encryption and decryption by private key
|
||||
encrypted_by_private_key = private_cipher_rsa.encrypt(message=raw_text_bytes)
|
||||
decrypted_by_private_key = private_cipher_rsa.decrypt(encrypted_by_private_key)
|
||||
assert decrypted_by_private_key == raw_text_bytes
|
||||
411
dify/api/tests/unit_tests/libs/test_schedule_utils_enhanced.py
Normal file
411
dify/api/tests/unit_tests/libs/test_schedule_utils_enhanced.py
Normal file
@@ -0,0 +1,411 @@
|
||||
"""
|
||||
Enhanced schedule_utils tests for new cron syntax support.
|
||||
|
||||
These tests verify that the backend schedule_utils functions properly support
|
||||
the enhanced cron syntax introduced in the frontend, ensuring full compatibility.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
import pytz
|
||||
from croniter import CroniterBadCronError
|
||||
|
||||
from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h
|
||||
|
||||
|
||||
class TestEnhancedCronSyntax(unittest.TestCase):
|
||||
"""Test enhanced cron syntax in calculate_next_run_at."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test with fixed time."""
|
||||
# Monday, January 15, 2024, 10:00 AM UTC
|
||||
self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)
|
||||
|
||||
def test_month_abbreviations(self):
|
||||
"""Test month abbreviations (JAN, FEB, etc.)."""
|
||||
test_cases = [
|
||||
("0 12 1 JAN *", 1), # January
|
||||
("0 12 1 FEB *", 2), # February
|
||||
("0 12 1 MAR *", 3), # March
|
||||
("0 12 1 APR *", 4), # April
|
||||
("0 12 1 MAY *", 5), # May
|
||||
("0 12 1 JUN *", 6), # June
|
||||
("0 12 1 JUL *", 7), # July
|
||||
("0 12 1 AUG *", 8), # August
|
||||
("0 12 1 SEP *", 9), # September
|
||||
("0 12 1 OCT *", 10), # October
|
||||
("0 12 1 NOV *", 11), # November
|
||||
("0 12 1 DEC *", 12), # December
|
||||
]
|
||||
|
||||
for expr, expected_month in test_cases:
|
||||
with self.subTest(expr=expr):
|
||||
result = calculate_next_run_at(expr, "UTC", self.base_time)
|
||||
assert result is not None, f"Failed to parse: {expr}"
|
||||
assert result.month == expected_month
|
||||
assert result.day == 1
|
||||
assert result.hour == 12
|
||||
assert result.minute == 0
|
||||
|
||||
def test_weekday_abbreviations(self):
|
||||
"""Test weekday abbreviations (SUN, MON, etc.)."""
|
||||
test_cases = [
|
||||
("0 9 * * SUN", 6), # Sunday (weekday() = 6)
|
||||
("0 9 * * MON", 0), # Monday (weekday() = 0)
|
||||
("0 9 * * TUE", 1), # Tuesday
|
||||
("0 9 * * WED", 2), # Wednesday
|
||||
("0 9 * * THU", 3), # Thursday
|
||||
("0 9 * * FRI", 4), # Friday
|
||||
("0 9 * * SAT", 5), # Saturday
|
||||
]
|
||||
|
||||
for expr, expected_weekday in test_cases:
|
||||
with self.subTest(expr=expr):
|
||||
result = calculate_next_run_at(expr, "UTC", self.base_time)
|
||||
assert result is not None, f"Failed to parse: {expr}"
|
||||
assert result.weekday() == expected_weekday
|
||||
assert result.hour == 9
|
||||
assert result.minute == 0
|
||||
|
||||
def test_sunday_dual_representation(self):
|
||||
"""Test Sunday as both 0 and 7."""
|
||||
base_time = datetime(2024, 1, 14, 10, 0, 0, tzinfo=UTC) # Sunday
|
||||
|
||||
# Both should give the same next Sunday
|
||||
result_0 = calculate_next_run_at("0 10 * * 0", "UTC", base_time)
|
||||
result_7 = calculate_next_run_at("0 10 * * 7", "UTC", base_time)
|
||||
result_SUN = calculate_next_run_at("0 10 * * SUN", "UTC", base_time)
|
||||
|
||||
assert result_0 is not None
|
||||
assert result_7 is not None
|
||||
assert result_SUN is not None
|
||||
|
||||
# All should be Sundays
|
||||
assert result_0.weekday() == 6 # Sunday = 6 in weekday()
|
||||
assert result_7.weekday() == 6
|
||||
assert result_SUN.weekday() == 6
|
||||
|
||||
# Times should be identical
|
||||
assert result_0 == result_7
|
||||
assert result_0 == result_SUN
|
||||
|
||||
def test_predefined_expressions(self):
|
||||
"""Test predefined expressions (@daily, @weekly, etc.)."""
|
||||
test_cases = [
|
||||
("@yearly", lambda dt: dt.month == 1 and dt.day == 1 and dt.hour == 0 and dt.minute == 0),
|
||||
("@annually", lambda dt: dt.month == 1 and dt.day == 1 and dt.hour == 0 and dt.minute == 0),
|
||||
("@monthly", lambda dt: dt.day == 1 and dt.hour == 0 and dt.minute == 0),
|
||||
("@weekly", lambda dt: dt.weekday() == 6 and dt.hour == 0 and dt.minute == 0), # Sunday
|
||||
("@daily", lambda dt: dt.hour == 0 and dt.minute == 0),
|
||||
("@midnight", lambda dt: dt.hour == 0 and dt.minute == 0),
|
||||
("@hourly", lambda dt: dt.minute == 0),
|
||||
]
|
||||
|
||||
for expr, validator in test_cases:
|
||||
with self.subTest(expr=expr):
|
||||
result = calculate_next_run_at(expr, "UTC", self.base_time)
|
||||
assert result is not None, f"Failed to parse: {expr}"
|
||||
assert validator(result), f"Validator failed for {expr}: {result}"
|
||||
|
||||
def test_question_mark_wildcard(self):
|
||||
"""Test ? wildcard character."""
|
||||
# ? in day position with specific weekday
|
||||
result_question = calculate_next_run_at("0 9 ? * 1", "UTC", self.base_time) # Monday
|
||||
result_star = calculate_next_run_at("0 9 * * 1", "UTC", self.base_time) # Monday
|
||||
|
||||
assert result_question is not None
|
||||
assert result_star is not None
|
||||
|
||||
# Both should return Mondays at 9:00
|
||||
assert result_question.weekday() == 0 # Monday
|
||||
assert result_star.weekday() == 0
|
||||
assert result_question.hour == 9
|
||||
assert result_star.hour == 9
|
||||
|
||||
# Results should be identical
|
||||
assert result_question == result_star
|
||||
|
||||
def test_last_day_of_month(self):
|
||||
"""Test 'L' for last day of month."""
|
||||
expr = "0 12 L * *" # Last day of month at noon
|
||||
|
||||
# Test for February (28 days in 2024 - not a leap year check)
|
||||
feb_base = datetime(2024, 2, 15, 10, 0, 0, tzinfo=UTC)
|
||||
result = calculate_next_run_at(expr, "UTC", feb_base)
|
||||
assert result is not None
|
||||
assert result.month == 2
|
||||
assert result.day == 29 # 2024 is a leap year
|
||||
assert result.hour == 12
|
||||
|
||||
def test_range_with_abbreviations(self):
|
||||
"""Test ranges using abbreviations."""
|
||||
test_cases = [
|
||||
"0 9 * * MON-FRI", # Weekday range
|
||||
"0 12 * JAN-MAR *", # Q1 months
|
||||
"0 15 * APR-JUN *", # Q2 months
|
||||
]
|
||||
|
||||
for expr in test_cases:
|
||||
with self.subTest(expr=expr):
|
||||
result = calculate_next_run_at(expr, "UTC", self.base_time)
|
||||
assert result is not None, f"Failed to parse range expression: {expr}"
|
||||
assert result > self.base_time
|
||||
|
||||
def test_list_with_abbreviations(self):
|
||||
"""Test lists using abbreviations."""
|
||||
test_cases = [
|
||||
("0 9 * * SUN,WED,FRI", [6, 2, 4]), # Specific weekdays
|
||||
("0 12 1 JAN,JUN,DEC *", [1, 6, 12]), # Specific months
|
||||
]
|
||||
|
||||
for expr, expected_values in test_cases:
|
||||
with self.subTest(expr=expr):
|
||||
result = calculate_next_run_at(expr, "UTC", self.base_time)
|
||||
assert result is not None, f"Failed to parse list expression: {expr}"
|
||||
|
||||
if "* *" in expr: # Weekday test
|
||||
assert result.weekday() in expected_values
|
||||
else: # Month test
|
||||
assert result.month in expected_values
|
||||
|
||||
def test_mixed_syntax(self):
|
||||
"""Test mixed traditional and enhanced syntax."""
|
||||
test_cases = [
|
||||
"30 14 15 JAN,JUN,DEC *", # Numbers + month abbreviations
|
||||
"0 9 * JAN-MAR MON-FRI", # Month range + weekday range
|
||||
"45 8 1,15 * MON", # Numbers + weekday abbreviation
|
||||
]
|
||||
|
||||
for expr in test_cases:
|
||||
with self.subTest(expr=expr):
|
||||
result = calculate_next_run_at(expr, "UTC", self.base_time)
|
||||
assert result is not None, f"Failed to parse mixed syntax: {expr}"
|
||||
assert result > self.base_time
|
||||
|
||||
def test_complex_enhanced_expressions(self):
|
||||
"""Test complex expressions with multiple enhanced features."""
|
||||
# Note: Some of these might not be supported by croniter, that's OK
|
||||
complex_expressions = [
|
||||
"0 9 L JAN *", # Last day of January
|
||||
"30 14 * * FRI#1", # First Friday of month (if supported)
|
||||
"0 12 15 JAN-DEC/3 *", # 15th of every 3rd month (quarterly)
|
||||
]
|
||||
|
||||
for expr in complex_expressions:
|
||||
with self.subTest(expr=expr):
|
||||
try:
|
||||
result = calculate_next_run_at(expr, "UTC", self.base_time)
|
||||
if result: # If supported, should return valid result
|
||||
assert result > self.base_time
|
||||
except Exception:
|
||||
# Some complex expressions might not be supported - that's acceptable
|
||||
pass
|
||||
|
||||
|
||||
class TestTimezoneHandlingEnhanced(unittest.TestCase):
|
||||
"""Test timezone handling with enhanced syntax."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test with fixed time."""
|
||||
self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)
|
||||
|
||||
def test_enhanced_syntax_with_timezones(self):
|
||||
"""Test enhanced syntax works correctly across timezones."""
|
||||
timezones = ["UTC", "America/New_York", "Asia/Tokyo", "Europe/London"]
|
||||
expression = "0 12 * * MON" # Monday at noon
|
||||
|
||||
for timezone in timezones:
|
||||
with self.subTest(timezone=timezone):
|
||||
result = calculate_next_run_at(expression, timezone, self.base_time)
|
||||
assert result is not None
|
||||
|
||||
# Convert to local timezone to verify it's Monday at noon
|
||||
tz = pytz.timezone(timezone)
|
||||
local_time = result.astimezone(tz)
|
||||
assert local_time.weekday() == 0 # Monday
|
||||
assert local_time.hour == 12
|
||||
assert local_time.minute == 0
|
||||
|
||||
def test_predefined_expressions_with_timezones(self):
|
||||
"""Test predefined expressions work with different timezones."""
|
||||
expression = "@daily"
|
||||
timezones = ["UTC", "America/New_York", "Asia/Tokyo"]
|
||||
|
||||
for timezone in timezones:
|
||||
with self.subTest(timezone=timezone):
|
||||
result = calculate_next_run_at(expression, timezone, self.base_time)
|
||||
assert result is not None
|
||||
|
||||
# Should be midnight in the specified timezone
|
||||
tz = pytz.timezone(timezone)
|
||||
local_time = result.astimezone(tz)
|
||||
assert local_time.hour == 0
|
||||
assert local_time.minute == 0
|
||||
|
||||
def test_dst_with_enhanced_syntax(self):
|
||||
"""Test DST handling with enhanced syntax."""
|
||||
# DST spring forward date in 2024
|
||||
dst_base = datetime(2024, 3, 8, 10, 0, 0, tzinfo=UTC)
|
||||
expression = "0 2 * * SUN" # Sunday at 2 AM (problematic during DST)
|
||||
timezone = "America/New_York"
|
||||
|
||||
result = calculate_next_run_at(expression, timezone, dst_base)
|
||||
assert result is not None
|
||||
|
||||
# Should handle DST transition gracefully
|
||||
tz = pytz.timezone(timezone)
|
||||
local_time = result.astimezone(tz)
|
||||
assert local_time.weekday() == 6 # Sunday
|
||||
|
||||
# During DST spring forward, 2 AM might become 3 AM
|
||||
assert local_time.hour in [2, 3]
|
||||
|
||||
|
||||
class TestErrorHandlingEnhanced(unittest.TestCase):
|
||||
"""Test error handling for enhanced syntax."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test with fixed time."""
|
||||
self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)
|
||||
|
||||
def test_invalid_enhanced_syntax(self):
|
||||
"""Test that invalid enhanced syntax raises appropriate errors."""
|
||||
invalid_expressions = [
|
||||
"0 12 * JANUARY *", # Full month name
|
||||
"0 12 * * MONDAY", # Full day name
|
||||
"0 12 32 JAN *", # Invalid day with valid month
|
||||
"0 12 * * MON-SUN-FRI", # Invalid range syntax
|
||||
"0 12 * JAN- *", # Incomplete range
|
||||
"0 12 * * ,MON", # Invalid list syntax
|
||||
"@INVALID", # Invalid predefined
|
||||
]
|
||||
|
||||
for expr in invalid_expressions:
|
||||
with self.subTest(expr=expr):
|
||||
with pytest.raises((CroniterBadCronError, ValueError)):
|
||||
calculate_next_run_at(expr, "UTC", self.base_time)
|
||||
|
||||
def test_boundary_values_with_enhanced_syntax(self):
|
||||
"""Test boundary values work with enhanced syntax."""
|
||||
# Valid boundary expressions
|
||||
valid_expressions = [
|
||||
"0 0 1 JAN *", # Minimum: January 1st midnight
|
||||
"59 23 31 DEC *", # Maximum: December 31st 23:59
|
||||
"0 12 29 FEB *", # Leap year boundary
|
||||
]
|
||||
|
||||
for expr in valid_expressions:
|
||||
with self.subTest(expr=expr):
|
||||
try:
|
||||
result = calculate_next_run_at(expr, "UTC", self.base_time)
|
||||
if result: # Some dates might not occur soon
|
||||
assert result > self.base_time
|
||||
except Exception as e:
|
||||
# Some boundary cases might be complex to calculate
|
||||
self.fail(f"Valid boundary expression failed: {expr} - {e}")
|
||||
|
||||
|
||||
class TestPerformanceEnhanced(unittest.TestCase):
|
||||
"""Test performance with enhanced syntax."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test with fixed time."""
|
||||
self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)
|
||||
|
||||
def test_complex_expression_performance(self):
|
||||
"""Test that complex enhanced expressions parse within reasonable time."""
|
||||
import time
|
||||
|
||||
complex_expressions = [
|
||||
"*/5 9-17 * * MON-FRI", # Every 5 min, weekdays, business hours
|
||||
"0 9 * JAN-MAR MON-FRI", # Q1 weekdays at 9 AM
|
||||
"30 14 1,15 * * ", # 1st and 15th at 14:30
|
||||
"0 12 ? * SUN", # Sundays at noon with ?
|
||||
"@daily", # Predefined expression
|
||||
]
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
for expr in complex_expressions:
|
||||
with self.subTest(expr=expr):
|
||||
try:
|
||||
result = calculate_next_run_at(expr, "UTC", self.base_time)
|
||||
assert result is not None
|
||||
except Exception:
|
||||
# Some expressions might not be supported - acceptable
|
||||
pass
|
||||
|
||||
end_time = time.time()
|
||||
execution_time = (end_time - start_time) * 1000 # milliseconds
|
||||
|
||||
# Should be fast (less than 100ms for all expressions)
|
||||
assert execution_time < 100, "Enhanced expressions should parse quickly"
|
||||
|
||||
def test_multiple_calculations_performance(self):
|
||||
"""Test performance when calculating multiple next times."""
|
||||
import time
|
||||
|
||||
expression = "0 9 * * MON-FRI" # Weekdays at 9 AM
|
||||
iterations = 20
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
current_time = self.base_time
|
||||
for _ in range(iterations):
|
||||
result = calculate_next_run_at(expression, "UTC", current_time)
|
||||
assert result is not None
|
||||
current_time = result + timedelta(seconds=1) # Move forward slightly
|
||||
|
||||
end_time = time.time()
|
||||
total_time = (end_time - start_time) * 1000 # milliseconds
|
||||
avg_time = total_time / iterations
|
||||
|
||||
# Average should be very fast (less than 5ms per calculation)
|
||||
assert avg_time < 5, f"Average calculation time too slow: {avg_time}ms"
|
||||
|
||||
|
||||
class TestRegressionEnhanced(unittest.TestCase):
|
||||
"""Regression tests to ensure enhanced syntax doesn't break existing functionality."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test with fixed time."""
|
||||
self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)
|
||||
|
||||
def test_traditional_syntax_still_works(self):
|
||||
"""Ensure traditional cron syntax continues to work."""
|
||||
traditional_expressions = [
|
||||
"15 10 1 * *", # Monthly 1st at 10:15
|
||||
"0 0 * * 0", # Weekly Sunday midnight
|
||||
"*/5 * * * *", # Every 5 minutes
|
||||
"0 9-17 * * 1-5", # Business hours weekdays
|
||||
"30 14 * * 1", # Monday 14:30
|
||||
"0 0 1,15 * *", # 1st and 15th midnight
|
||||
]
|
||||
|
||||
for expr in traditional_expressions:
|
||||
with self.subTest(expr=expr):
|
||||
result = calculate_next_run_at(expr, "UTC", self.base_time)
|
||||
assert result is not None, f"Traditional expression failed: {expr}"
|
||||
assert result > self.base_time
|
||||
|
||||
def test_convert_12h_to_24h_unchanged(self):
|
||||
"""Ensure convert_12h_to_24h function is unchanged."""
|
||||
test_cases = [
|
||||
("12:00 AM", (0, 0)), # Midnight
|
||||
("12:00 PM", (12, 0)), # Noon
|
||||
("1:30 AM", (1, 30)), # Early morning
|
||||
("11:45 PM", (23, 45)), # Late evening
|
||||
("6:15 AM", (6, 15)), # Morning
|
||||
("3:30 PM", (15, 30)), # Afternoon
|
||||
]
|
||||
|
||||
for time_str, expected in test_cases:
|
||||
with self.subTest(time_str=time_str):
|
||||
result = convert_12h_to_24h(time_str)
|
||||
assert result == expected, f"12h conversion failed: {time_str}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
53
dify/api/tests/unit_tests/libs/test_sendgrid_client.py
Normal file
53
dify/api/tests/unit_tests/libs/test_sendgrid_client.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from python_http_client.exceptions import UnauthorizedError
|
||||
|
||||
from libs.sendgrid import SendGridClient
|
||||
|
||||
|
||||
def _mail(to: str = "user@example.com") -> dict:
|
||||
return {"to": to, "subject": "Hi", "html": "<b>Hi</b>"}
|
||||
|
||||
|
||||
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
|
||||
def test_sendgrid_success(mock_client_cls: MagicMock):
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
# nested attribute access: client.mail.send.post
|
||||
mock_client.client.mail.send.post.return_value = MagicMock(status_code=202, body=b"", headers={})
|
||||
|
||||
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
|
||||
sg.send(_mail())
|
||||
|
||||
mock_client_cls.assert_called_once()
|
||||
mock_client.client.mail.send.post.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
|
||||
def test_sendgrid_missing_to_raises(mock_client_cls: MagicMock):
|
||||
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
|
||||
with pytest.raises(ValueError):
|
||||
sg.send(_mail(to=""))
|
||||
|
||||
|
||||
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
|
||||
def test_sendgrid_auth_errors_reraise(mock_client_cls: MagicMock):
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.client.mail.send.post.side_effect = UnauthorizedError(401, "Unauthorized", b"{}", {})
|
||||
|
||||
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
|
||||
with pytest.raises(UnauthorizedError):
|
||||
sg.send(_mail())
|
||||
|
||||
|
||||
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
|
||||
def test_sendgrid_timeout_reraise(mock_client_cls: MagicMock):
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.client.mail.send.post.side_effect = TimeoutError("timeout")
|
||||
|
||||
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
|
||||
with pytest.raises(TimeoutError):
|
||||
sg.send(_mail())
|
||||
100
dify/api/tests/unit_tests/libs/test_smtp_client.py
Normal file
100
dify/api/tests/unit_tests/libs/test_smtp_client.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.smtp import SMTPClient
|
||||
|
||||
|
||||
def _mail() -> dict:
|
||||
return {"to": "user@example.com", "subject": "Hi", "html": "<b>Hi</b>"}
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP")
|
||||
def test_smtp_plain_success(mock_smtp_cls: MagicMock):
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp_cls.return_value = mock_smtp
|
||||
|
||||
client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com")
|
||||
client.send(_mail())
|
||||
|
||||
mock_smtp_cls.assert_called_once_with("smtp.example.com", 25, timeout=10)
|
||||
mock_smtp.sendmail.assert_called_once()
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP")
|
||||
def test_smtp_tls_opportunistic_success(mock_smtp_cls: MagicMock):
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp_cls.return_value = mock_smtp
|
||||
|
||||
client = SMTPClient(
|
||||
server="smtp.example.com",
|
||||
port=587,
|
||||
username="user",
|
||||
password="pass",
|
||||
_from="noreply@example.com",
|
||||
use_tls=True,
|
||||
opportunistic_tls=True,
|
||||
)
|
||||
client.send(_mail())
|
||||
|
||||
mock_smtp_cls.assert_called_once_with("smtp.example.com", 587, timeout=10)
|
||||
assert mock_smtp.ehlo.call_count == 2
|
||||
mock_smtp.starttls.assert_called_once()
|
||||
mock_smtp.login.assert_called_once_with("user", "pass")
|
||||
mock_smtp.sendmail.assert_called_once()
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP_SSL")
|
||||
def test_smtp_tls_ssl_branch_and_timeout(mock_smtp_ssl_cls: MagicMock):
|
||||
# Cover SMTP_SSL branch and TimeoutError handling
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp.sendmail.side_effect = TimeoutError("timeout")
|
||||
mock_smtp_ssl_cls.return_value = mock_smtp
|
||||
|
||||
client = SMTPClient(
|
||||
server="smtp.example.com",
|
||||
port=465,
|
||||
username="",
|
||||
password="",
|
||||
_from="noreply@example.com",
|
||||
use_tls=True,
|
||||
opportunistic_tls=False,
|
||||
)
|
||||
with pytest.raises(TimeoutError):
|
||||
client.send(_mail())
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP")
|
||||
def test_smtp_generic_exception_propagates(mock_smtp_cls: MagicMock):
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp.sendmail.side_effect = RuntimeError("oops")
|
||||
mock_smtp_cls.return_value = mock_smtp
|
||||
|
||||
client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com")
|
||||
with pytest.raises(RuntimeError):
|
||||
client.send(_mail())
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP")
|
||||
def test_smtp_smtplib_exception_in_login(mock_smtp_cls: MagicMock):
|
||||
# Ensure we hit the specific SMTPException except branch
|
||||
import smtplib
|
||||
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp.login.side_effect = smtplib.SMTPException("login-fail")
|
||||
mock_smtp_cls.return_value = mock_smtp
|
||||
|
||||
client = SMTPClient(
|
||||
server="smtp.example.com",
|
||||
port=25,
|
||||
username="user", # non-empty to trigger login
|
||||
password="pass",
|
||||
_from="noreply@example.com",
|
||||
)
|
||||
with pytest.raises(smtplib.SMTPException):
|
||||
client.send(_mail())
|
||||
mock_smtp.quit.assert_called_once()
|
||||
91
dify/api/tests/unit_tests/libs/test_time_parser.py
Normal file
91
dify/api/tests/unit_tests/libs/test_time_parser.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Unit tests for time parser utility."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from libs.time_parser import get_time_threshold, parse_time_duration
|
||||
|
||||
|
||||
class TestParseTimeDuration:
|
||||
"""Test parse_time_duration function."""
|
||||
|
||||
def test_parse_days(self):
|
||||
"""Test parsing days."""
|
||||
result = parse_time_duration("7d")
|
||||
assert result == timedelta(days=7)
|
||||
|
||||
def test_parse_hours(self):
|
||||
"""Test parsing hours."""
|
||||
result = parse_time_duration("4h")
|
||||
assert result == timedelta(hours=4)
|
||||
|
||||
def test_parse_minutes(self):
|
||||
"""Test parsing minutes."""
|
||||
result = parse_time_duration("30m")
|
||||
assert result == timedelta(minutes=30)
|
||||
|
||||
def test_parse_seconds(self):
|
||||
"""Test parsing seconds."""
|
||||
result = parse_time_duration("30s")
|
||||
assert result == timedelta(seconds=30)
|
||||
|
||||
def test_parse_uppercase(self):
|
||||
"""Test parsing uppercase units."""
|
||||
result = parse_time_duration("7D")
|
||||
assert result == timedelta(days=7)
|
||||
|
||||
def test_parse_invalid_format(self):
|
||||
"""Test parsing invalid format."""
|
||||
result = parse_time_duration("7days")
|
||||
assert result is None
|
||||
|
||||
result = parse_time_duration("abc")
|
||||
assert result is None
|
||||
|
||||
result = parse_time_duration("7")
|
||||
assert result is None
|
||||
|
||||
def test_parse_empty_string(self):
|
||||
"""Test parsing empty string."""
|
||||
result = parse_time_duration("")
|
||||
assert result is None
|
||||
|
||||
def test_parse_none(self):
|
||||
"""Test parsing None."""
|
||||
result = parse_time_duration(None)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetTimeThreshold:
|
||||
"""Test get_time_threshold function."""
|
||||
|
||||
def test_get_threshold_days(self):
|
||||
"""Test getting threshold for days."""
|
||||
before = datetime.now(UTC)
|
||||
result = get_time_threshold("7d")
|
||||
after = datetime.now(UTC)
|
||||
|
||||
assert result is not None
|
||||
# Result should be approximately 7 days ago
|
||||
expected = before - timedelta(days=7)
|
||||
# Allow 1 second tolerance for test execution time
|
||||
assert abs((result - expected).total_seconds()) < 1
|
||||
|
||||
def test_get_threshold_hours(self):
|
||||
"""Test getting threshold for hours."""
|
||||
before = datetime.now(UTC)
|
||||
result = get_time_threshold("4h")
|
||||
after = datetime.now(UTC)
|
||||
|
||||
assert result is not None
|
||||
expected = before - timedelta(hours=4)
|
||||
assert abs((result - expected).total_seconds()) < 1
|
||||
|
||||
def test_get_threshold_invalid(self):
|
||||
"""Test getting threshold with invalid duration."""
|
||||
result = get_time_threshold("invalid")
|
||||
assert result is None
|
||||
|
||||
def test_get_threshold_none(self):
|
||||
"""Test getting threshold with None."""
|
||||
result = get_time_threshold(None)
|
||||
assert result is None
|
||||
62
dify/api/tests/unit_tests/libs/test_token.py
Normal file
62
dify/api/tests/unit_tests/libs/test_token.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from werkzeug.wrappers import Response
|
||||
|
||||
from constants import COOKIE_NAME_ACCESS_TOKEN, COOKIE_NAME_WEBAPP_ACCESS_TOKEN
|
||||
from libs import token
|
||||
from libs.token import extract_access_token, extract_webapp_access_token, set_csrf_token_to_cookie
|
||||
|
||||
|
||||
class MockRequest:
|
||||
def __init__(self, headers: dict[str, str], cookies: dict[str, str], args: dict[str, str]):
|
||||
self.headers: dict[str, str] = headers
|
||||
self.cookies: dict[str, str] = cookies
|
||||
self.args: dict[str, str] = args
|
||||
|
||||
|
||||
def test_extract_access_token():
|
||||
def _mock_request(headers: dict[str, str], cookies: dict[str, str], args: dict[str, str]):
|
||||
return MockRequest(headers, cookies, args)
|
||||
|
||||
test_cases = [
|
||||
(_mock_request({"Authorization": "Bearer 123"}, {}, {}), "123", "123"),
|
||||
(_mock_request({}, {COOKIE_NAME_ACCESS_TOKEN: "123"}, {}), "123", None),
|
||||
(_mock_request({}, {}, {}), None, None),
|
||||
(_mock_request({"Authorization": "Bearer_aaa 123"}, {}, {}), None, None),
|
||||
(_mock_request({}, {COOKIE_NAME_WEBAPP_ACCESS_TOKEN: "123"}, {}), None, "123"),
|
||||
]
|
||||
for request, expected_console, expected_webapp in test_cases:
|
||||
assert extract_access_token(request) == expected_console # pyright: ignore[reportArgumentType]
|
||||
assert extract_webapp_access_token(request) == expected_webapp # pyright: ignore[reportArgumentType]
|
||||
|
||||
|
||||
def test_real_cookie_name_uses_host_prefix_without_domain(monkeypatch):
|
||||
monkeypatch.setattr(token.dify_config, "CONSOLE_WEB_URL", "https://console.example.com", raising=False)
|
||||
monkeypatch.setattr(token.dify_config, "CONSOLE_API_URL", "https://api.example.com", raising=False)
|
||||
monkeypatch.setattr(token.dify_config, "COOKIE_DOMAIN", "", raising=False)
|
||||
|
||||
assert token._real_cookie_name("csrf_token") == "__Host-csrf_token"
|
||||
|
||||
|
||||
def test_real_cookie_name_without_host_prefix_when_domain_present(monkeypatch):
|
||||
monkeypatch.setattr(token.dify_config, "CONSOLE_WEB_URL", "https://console.example.com", raising=False)
|
||||
monkeypatch.setattr(token.dify_config, "CONSOLE_API_URL", "https://api.example.com", raising=False)
|
||||
monkeypatch.setattr(token.dify_config, "COOKIE_DOMAIN", ".example.com", raising=False)
|
||||
|
||||
assert token._real_cookie_name("csrf_token") == "csrf_token"
|
||||
|
||||
|
||||
def test_set_csrf_cookie_includes_domain_when_configured(monkeypatch):
|
||||
monkeypatch.setattr(token.dify_config, "CONSOLE_WEB_URL", "https://console.example.com", raising=False)
|
||||
monkeypatch.setattr(token.dify_config, "CONSOLE_API_URL", "https://api.example.com", raising=False)
|
||||
monkeypatch.setattr(token.dify_config, "COOKIE_DOMAIN", ".example.com", raising=False)
|
||||
|
||||
response = Response()
|
||||
request = MagicMock()
|
||||
|
||||
set_csrf_token_to_cookie(request, response, "abc123")
|
||||
|
||||
cookies = response.headers.getlist("Set-Cookie")
|
||||
assert any("csrf_token=abc123" in c for c in cookies)
|
||||
assert any("Domain=example.com" in c for c in cookies)
|
||||
assert all("__Host-" not in c for c in cookies)
|
||||
351
dify/api/tests/unit_tests/libs/test_uuid_utils.py
Normal file
351
dify/api/tests/unit_tests/libs/test_uuid_utils.py
Normal file
@@ -0,0 +1,351 @@
|
||||
import struct
|
||||
import time
|
||||
import uuid
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from hypothesis import given
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from libs.uuid_utils import _create_uuidv7_bytes, uuidv7, uuidv7_boundary, uuidv7_timestamp
|
||||
|
||||
|
||||
# Tests for private helper function _create_uuidv7_bytes
|
||||
def test_create_uuidv7_bytes_basic_structure():
|
||||
"""Test basic byte structure creation."""
|
||||
timestamp_ms = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
|
||||
random_bytes = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x11\x22"
|
||||
|
||||
result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
|
||||
|
||||
# Should be exactly 16 bytes
|
||||
assert len(result) == 16
|
||||
assert isinstance(result, bytes)
|
||||
|
||||
# Create UUID from bytes to verify it's valid
|
||||
uuid_obj = uuid.UUID(bytes=result)
|
||||
assert uuid_obj.version == 7
|
||||
|
||||
|
||||
def test_create_uuidv7_bytes_timestamp_encoding():
|
||||
"""Test timestamp is correctly encoded in first 48 bits."""
|
||||
timestamp_ms = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
|
||||
random_bytes = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
|
||||
result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
|
||||
|
||||
# Extract timestamp from first 6 bytes
|
||||
timestamp_bytes = b"\x00\x00" + result[0:6]
|
||||
extracted_timestamp = struct.unpack(">Q", timestamp_bytes)[0]
|
||||
|
||||
assert extracted_timestamp == timestamp_ms
|
||||
|
||||
|
||||
def test_create_uuidv7_bytes_version_bits():
|
||||
"""Test version bits are set to 7."""
|
||||
timestamp_ms = 1609459200000
|
||||
random_bytes = b"\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00" # Set first 2 bytes to all 1s
|
||||
|
||||
result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
|
||||
|
||||
# Extract version from bytes 6-7
|
||||
version_and_rand_a = struct.unpack(">H", result[6:8])[0]
|
||||
version = (version_and_rand_a >> 12) & 0x0F
|
||||
|
||||
assert version == 7
|
||||
|
||||
|
||||
def test_create_uuidv7_bytes_variant_bits():
|
||||
"""Test variant bits are set correctly."""
|
||||
timestamp_ms = 1609459200000
|
||||
random_bytes = b"\x00\x00\xff\x00\x00\x00\x00\x00\x00\x00" # Set byte 8 to all 1s
|
||||
|
||||
result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
|
||||
|
||||
# Check variant bits in byte 8 (should be 10xxxxxx)
|
||||
variant_byte = result[8]
|
||||
variant_bits = (variant_byte >> 6) & 0b11
|
||||
|
||||
assert variant_bits == 0b10 # Should be binary 10
|
||||
|
||||
|
||||
def test_create_uuidv7_bytes_random_data():
|
||||
"""Test random bytes are placed correctly."""
|
||||
timestamp_ms = 1609459200000
|
||||
random_bytes = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x11\x22"
|
||||
|
||||
result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
|
||||
|
||||
# Check random data A (12 bits from bytes 6-7, excluding version)
|
||||
version_and_rand_a = struct.unpack(">H", result[6:8])[0]
|
||||
rand_a = version_and_rand_a & 0x0FFF
|
||||
expected_rand_a = struct.unpack(">H", random_bytes[0:2])[0] & 0x0FFF
|
||||
assert rand_a == expected_rand_a
|
||||
|
||||
# Check random data B (bytes 8-15, with variant bits preserved)
|
||||
# Byte 8 should have variant bits set but preserve lower 6 bits
|
||||
expected_byte_8 = (random_bytes[2] & 0x3F) | 0x80
|
||||
assert result[8] == expected_byte_8
|
||||
|
||||
# Bytes 9-15 should match random_bytes[3:10]
|
||||
assert result[9:16] == random_bytes[3:10]
|
||||
|
||||
|
||||
def test_create_uuidv7_bytes_zero_random():
|
||||
"""Test with zero random bytes (boundary case)."""
|
||||
timestamp_ms = 1609459200000
|
||||
zero_random_bytes = b"\x00" * 10
|
||||
|
||||
result = _create_uuidv7_bytes(timestamp_ms, zero_random_bytes)
|
||||
|
||||
# Should still be valid UUIDv7
|
||||
uuid_obj = uuid.UUID(bytes=result)
|
||||
assert uuid_obj.version == 7
|
||||
|
||||
# Version bits should be 0x7000
|
||||
version_and_rand_a = struct.unpack(">H", result[6:8])[0]
|
||||
assert version_and_rand_a == 0x7000
|
||||
|
||||
# Variant byte should be 0x80 (variant bits + zero random bits)
|
||||
assert result[8] == 0x80
|
||||
|
||||
# Remaining bytes should be zero
|
||||
assert result[9:16] == b"\x00" * 7
|
||||
|
||||
|
||||
def test_uuidv7_basic_generation():
|
||||
"""Test basic UUID generation produces valid UUIDv7."""
|
||||
result = uuidv7()
|
||||
|
||||
# Should be a UUID object
|
||||
assert isinstance(result, uuid.UUID)
|
||||
|
||||
# Should be version 7
|
||||
assert result.version == 7
|
||||
|
||||
# Should have correct variant (RFC 4122 variant)
|
||||
# Variant bits should be 10xxxxxx (0x80-0xBF range)
|
||||
variant_byte = result.bytes[8]
|
||||
assert (variant_byte >> 6) == 0b10
|
||||
|
||||
|
||||
def test_uuidv7_with_custom_timestamp():
|
||||
"""Test UUID generation with custom timestamp."""
|
||||
custom_timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
|
||||
result = uuidv7(custom_timestamp)
|
||||
|
||||
assert isinstance(result, uuid.UUID)
|
||||
assert result.version == 7
|
||||
|
||||
# Extract and verify timestamp
|
||||
extracted_timestamp = uuidv7_timestamp(result)
|
||||
assert isinstance(extracted_timestamp, int)
|
||||
assert extracted_timestamp == custom_timestamp # Exact match for integer milliseconds
|
||||
|
||||
|
||||
def test_uuidv7_with_none_timestamp(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test UUID generation with None timestamp uses current time."""
|
||||
mock_time = 1609459200
|
||||
mock_time_func = mock.Mock(return_value=mock_time)
|
||||
monkeypatch.setattr("time.time", mock_time_func)
|
||||
result = uuidv7(None)
|
||||
|
||||
assert isinstance(result, uuid.UUID)
|
||||
assert result.version == 7
|
||||
|
||||
# Should use the mocked current time (converted to milliseconds)
|
||||
assert mock_time_func.called
|
||||
extracted_timestamp = uuidv7_timestamp(result)
|
||||
assert extracted_timestamp == mock_time * 1000 # 1609459200.0 * 1000
|
||||
|
||||
|
||||
def test_uuidv7_time_ordering():
|
||||
"""Test that sequential UUIDs have increasing timestamps."""
|
||||
# Generate UUIDs with incrementing timestamps (in milliseconds)
|
||||
timestamp1 = 1609459200000 # 2021-01-01 00:00:00 UTC
|
||||
timestamp2 = 1609459201000 # 2021-01-01 00:00:01 UTC
|
||||
timestamp3 = 1609459202000 # 2021-01-01 00:00:02 UTC
|
||||
|
||||
uuid1 = uuidv7(timestamp1)
|
||||
uuid2 = uuidv7(timestamp2)
|
||||
uuid3 = uuidv7(timestamp3)
|
||||
|
||||
# Extract timestamps
|
||||
ts1 = uuidv7_timestamp(uuid1)
|
||||
ts2 = uuidv7_timestamp(uuid2)
|
||||
ts3 = uuidv7_timestamp(uuid3)
|
||||
|
||||
# Should be in ascending order
|
||||
assert ts1 < ts2 < ts3
|
||||
|
||||
# UUIDs should be lexicographically ordered by their string representation
|
||||
# due to time-ordering property of UUIDv7
|
||||
uuid_strings = [str(uuid1), str(uuid2), str(uuid3)]
|
||||
assert uuid_strings == sorted(uuid_strings)
|
||||
|
||||
|
||||
def test_uuidv7_uniqueness():
|
||||
"""Test that multiple calls generate different UUIDs."""
|
||||
# Generate multiple UUIDs with the same timestamp (in milliseconds)
|
||||
timestamp = 1609459200000
|
||||
uuids = [uuidv7(timestamp) for _ in range(100)]
|
||||
|
||||
# All should be unique despite same timestamp (due to random bits)
|
||||
assert len(set(uuids)) == 100
|
||||
|
||||
# All should have the same extracted timestamp
|
||||
for uuid_obj in uuids:
|
||||
extracted_ts = uuidv7_timestamp(uuid_obj)
|
||||
assert extracted_ts == timestamp
|
||||
|
||||
|
||||
def test_uuidv7_timestamp_error_handling_wrong_version():
|
||||
"""Test error handling for non-UUIDv7 inputs."""
|
||||
|
||||
uuid_v4 = uuid.uuid4()
|
||||
with pytest.raises(ValueError) as exc_ctx:
|
||||
uuidv7_timestamp(uuid_v4)
|
||||
assert "Expected UUIDv7 (version 7)" in str(exc_ctx.value)
|
||||
assert f"got version {uuid_v4.version}" in str(exc_ctx.value)
|
||||
|
||||
|
||||
@given(st.integers(max_value=2**48 - 1, min_value=0))
|
||||
def test_uuidv7_timestamp_round_trip(timestamp_ms):
|
||||
# Generate UUID with timestamp
|
||||
uuid_obj = uuidv7(timestamp_ms)
|
||||
|
||||
# Extract timestamp back
|
||||
extracted_timestamp = uuidv7_timestamp(uuid_obj)
|
||||
|
||||
# Should match exactly for integer millisecond timestamps
|
||||
assert extracted_timestamp == timestamp_ms
|
||||
|
||||
|
||||
def test_uuidv7_timestamp_edge_cases():
|
||||
"""Test timestamp extraction with edge case values."""
|
||||
# Test with very small timestamp
|
||||
small_timestamp = 1 # 1ms after epoch
|
||||
uuid_small = uuidv7(small_timestamp)
|
||||
extracted_small = uuidv7_timestamp(uuid_small)
|
||||
assert extracted_small == small_timestamp
|
||||
|
||||
# Test with large timestamp (year 2038+)
|
||||
large_timestamp = 2147483647000 # 2038-01-19 03:14:07 UTC in milliseconds
|
||||
uuid_large = uuidv7(large_timestamp)
|
||||
extracted_large = uuidv7_timestamp(uuid_large)
|
||||
assert extracted_large == large_timestamp
|
||||
|
||||
|
||||
def test_uuidv7_boundary_basic_generation():
|
||||
"""Test basic boundary UUID generation with a known timestamp."""
|
||||
timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
|
||||
result = uuidv7_boundary(timestamp)
|
||||
|
||||
# Should be a UUID object
|
||||
assert isinstance(result, uuid.UUID)
|
||||
|
||||
# Should be version 7
|
||||
assert result.version == 7
|
||||
|
||||
# Should have correct variant (RFC 4122 variant)
|
||||
# Variant bits should be 10xxxxxx (0x80-0xBF range)
|
||||
variant_byte = result.bytes[8]
|
||||
assert (variant_byte >> 6) == 0b10
|
||||
|
||||
|
||||
def test_uuidv7_boundary_timestamp_extraction():
|
||||
"""Test that boundary UUID timestamp can be extracted correctly."""
|
||||
timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
|
||||
boundary_uuid = uuidv7_boundary(timestamp)
|
||||
|
||||
# Extract timestamp using existing function
|
||||
extracted_timestamp = uuidv7_timestamp(boundary_uuid)
|
||||
|
||||
# Should match exactly
|
||||
assert extracted_timestamp == timestamp
|
||||
|
||||
|
||||
def test_uuidv7_boundary_deterministic():
|
||||
"""Test that boundary UUIDs are deterministic for same timestamp."""
|
||||
timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
|
||||
|
||||
# Generate multiple boundary UUIDs with same timestamp
|
||||
uuid1 = uuidv7_boundary(timestamp)
|
||||
uuid2 = uuidv7_boundary(timestamp)
|
||||
uuid3 = uuidv7_boundary(timestamp)
|
||||
|
||||
# Should all be identical
|
||||
assert uuid1 == uuid2 == uuid3
|
||||
assert str(uuid1) == str(uuid2) == str(uuid3)
|
||||
|
||||
|
||||
def test_uuidv7_boundary_is_minimum():
|
||||
"""Test that boundary UUID is lexicographically smaller than regular UUIDs."""
|
||||
timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
|
||||
|
||||
# Generate boundary UUID
|
||||
boundary_uuid = uuidv7_boundary(timestamp)
|
||||
|
||||
# Generate multiple regular UUIDs with same timestamp
|
||||
regular_uuids = [uuidv7(timestamp) for _ in range(50)]
|
||||
|
||||
# Boundary UUID should be lexicographically smaller than all regular UUIDs
|
||||
boundary_str = str(boundary_uuid)
|
||||
for regular_uuid in regular_uuids:
|
||||
regular_str = str(regular_uuid)
|
||||
assert boundary_str < regular_str, f"Boundary {boundary_str} should be < regular {regular_str}"
|
||||
|
||||
# Also test with bytes comparison
|
||||
boundary_bytes = boundary_uuid.bytes
|
||||
for regular_uuid in regular_uuids:
|
||||
regular_bytes = regular_uuid.bytes
|
||||
assert boundary_bytes < regular_bytes
|
||||
|
||||
|
||||
def test_uuidv7_boundary_different_timestamps():
|
||||
"""Test that boundary UUIDs with different timestamps are ordered correctly."""
|
||||
timestamp1 = 1609459200000 # 2021-01-01 00:00:00 UTC
|
||||
timestamp2 = 1609459201000 # 2021-01-01 00:00:01 UTC
|
||||
timestamp3 = 1609459202000 # 2021-01-01 00:00:02 UTC
|
||||
|
||||
uuid1 = uuidv7_boundary(timestamp1)
|
||||
uuid2 = uuidv7_boundary(timestamp2)
|
||||
uuid3 = uuidv7_boundary(timestamp3)
|
||||
|
||||
# Extract timestamps to verify
|
||||
ts1 = uuidv7_timestamp(uuid1)
|
||||
ts2 = uuidv7_timestamp(uuid2)
|
||||
ts3 = uuidv7_timestamp(uuid3)
|
||||
|
||||
# Should be in ascending order
|
||||
assert ts1 < ts2 < ts3
|
||||
|
||||
# UUIDs should be lexicographically ordered
|
||||
uuid_strings = [str(uuid1), str(uuid2), str(uuid3)]
|
||||
assert uuid_strings == sorted(uuid_strings)
|
||||
|
||||
# Bytes should also be ordered
|
||||
assert uuid1.bytes < uuid2.bytes < uuid3.bytes
|
||||
|
||||
|
||||
def test_uuidv7_boundary_edge_cases():
|
||||
"""Test boundary UUID generation with edge case timestamp values."""
|
||||
# Test with timestamp 0 (Unix epoch)
|
||||
epoch_uuid = uuidv7_boundary(0)
|
||||
assert isinstance(epoch_uuid, uuid.UUID)
|
||||
assert epoch_uuid.version == 7
|
||||
assert uuidv7_timestamp(epoch_uuid) == 0
|
||||
|
||||
# Test with very large timestamp values
|
||||
large_timestamp = 2147483647000 # 2038-01-19 03:14:07 UTC in milliseconds
|
||||
large_uuid = uuidv7_boundary(large_timestamp)
|
||||
assert isinstance(large_uuid, uuid.UUID)
|
||||
assert large_uuid.version == 7
|
||||
assert uuidv7_timestamp(large_uuid) == large_timestamp
|
||||
|
||||
# Test with current time
|
||||
current_time = int(time.time() * 1000)
|
||||
current_uuid = uuidv7_boundary(current_time)
|
||||
assert isinstance(current_uuid, uuid.UUID)
|
||||
assert current_uuid.version == 7
|
||||
assert uuidv7_timestamp(current_uuid) == current_time
|
||||
29
dify/api/tests/unit_tests/libs/test_yarl.py
Normal file
29
dify/api/tests/unit_tests/libs/test_yarl.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import pytest
|
||||
from yarl import URL
|
||||
|
||||
|
||||
def test_yarl_urls():
|
||||
expected_1 = "https://dify.ai/api"
|
||||
assert str(URL("https://dify.ai") / "api") == expected_1
|
||||
assert str(URL("https://dify.ai/") / "api") == expected_1
|
||||
|
||||
expected_2 = "http://dify.ai:12345/api"
|
||||
assert str(URL("http://dify.ai:12345") / "api") == expected_2
|
||||
assert str(URL("http://dify.ai:12345/") / "api") == expected_2
|
||||
|
||||
expected_3 = "https://dify.ai/api/v1"
|
||||
assert str(URL("https://dify.ai") / "api" / "v1") == expected_3
|
||||
assert str(URL("https://dify.ai") / "api/v1") == expected_3
|
||||
assert str(URL("https://dify.ai/") / "api/v1") == expected_3
|
||||
assert str(URL("https://dify.ai/api") / "v1") == expected_3
|
||||
assert str(URL("https://dify.ai/api/") / "v1") == expected_3
|
||||
|
||||
expected_4 = "api"
|
||||
assert str(URL("") / "api") == expected_4
|
||||
|
||||
expected_5 = "/api"
|
||||
assert str(URL("/") / "api") == expected_5
|
||||
|
||||
with pytest.raises(ValueError) as e1:
|
||||
str(URL("https://dify.ai") / "/api")
|
||||
assert str(e1.value) == "Appending path '/api' starting from slash is forbidden"
|
||||
Reference in New Issue
Block a user