This commit is contained in:
2025-12-01 17:21:38 +08:00
parent 32fee2b8ab
commit fab8c13cb3
7511 changed files with 996300 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,381 @@
"""
Enhanced cron syntax compatibility tests for croniter backend.
This test suite mirrors the frontend cron-parser tests to ensure
complete compatibility between frontend and backend cron processing.
"""
import unittest
from datetime import UTC, datetime, timedelta
import pytest
import pytz
from croniter import CroniterBadCronError
from libs.schedule_utils import calculate_next_run_at
class TestCronCompatibility(unittest.TestCase):
"""Test enhanced cron syntax compatibility with frontend."""
def setUp(self):
"""Set up test environment with fixed time."""
self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)
def test_enhanced_dayofweek_syntax(self):
"""Test enhanced day-of-week syntax compatibility."""
test_cases = [
("0 9 * * 7", 0), # Sunday as 7
("0 9 * * 0", 0), # Sunday as 0
("0 9 * * MON", 1), # Monday abbreviation
("0 9 * * TUE", 2), # Tuesday abbreviation
("0 9 * * WED", 3), # Wednesday abbreviation
("0 9 * * THU", 4), # Thursday abbreviation
("0 9 * * FRI", 5), # Friday abbreviation
("0 9 * * SAT", 6), # Saturday abbreviation
("0 9 * * SUN", 0), # Sunday abbreviation
]
for expr, expected_weekday in test_cases:
with self.subTest(expr=expr):
next_time = calculate_next_run_at(expr, "UTC", self.base_time)
assert next_time is not None
assert (next_time.weekday() + 1 if next_time.weekday() < 6 else 0) == expected_weekday
assert next_time.hour == 9
assert next_time.minute == 0
def test_enhanced_month_syntax(self):
"""Test enhanced month syntax compatibility."""
test_cases = [
("0 9 1 JAN *", 1), # January abbreviation
("0 9 1 FEB *", 2), # February abbreviation
("0 9 1 MAR *", 3), # March abbreviation
("0 9 1 APR *", 4), # April abbreviation
("0 9 1 MAY *", 5), # May abbreviation
("0 9 1 JUN *", 6), # June abbreviation
("0 9 1 JUL *", 7), # July abbreviation
("0 9 1 AUG *", 8), # August abbreviation
("0 9 1 SEP *", 9), # September abbreviation
("0 9 1 OCT *", 10), # October abbreviation
("0 9 1 NOV *", 11), # November abbreviation
("0 9 1 DEC *", 12), # December abbreviation
]
for expr, expected_month in test_cases:
with self.subTest(expr=expr):
next_time = calculate_next_run_at(expr, "UTC", self.base_time)
assert next_time is not None
assert next_time.month == expected_month
assert next_time.day == 1
assert next_time.hour == 9
def test_predefined_expressions(self):
"""Test predefined cron expressions compatibility."""
test_cases = [
("@yearly", lambda dt: dt.month == 1 and dt.day == 1 and dt.hour == 0),
("@annually", lambda dt: dt.month == 1 and dt.day == 1 and dt.hour == 0),
("@monthly", lambda dt: dt.day == 1 and dt.hour == 0),
("@weekly", lambda dt: dt.weekday() == 6 and dt.hour == 0), # Sunday = 6 in weekday()
("@daily", lambda dt: dt.hour == 0 and dt.minute == 0),
("@midnight", lambda dt: dt.hour == 0 and dt.minute == 0),
("@hourly", lambda dt: dt.minute == 0),
]
for expr, validator in test_cases:
with self.subTest(expr=expr):
next_time = calculate_next_run_at(expr, "UTC", self.base_time)
assert next_time is not None
assert validator(next_time), f"Validator failed for {expr}: {next_time}"
def test_special_characters(self):
"""Test special characters in cron expressions."""
test_cases = [
"0 9 ? * 1", # ? wildcard
"0 12 * * 7", # Sunday as 7
"0 15 L * *", # Last day of month
]
for expr in test_cases:
with self.subTest(expr=expr):
try:
next_time = calculate_next_run_at(expr, "UTC", self.base_time)
assert next_time is not None
assert next_time > self.base_time
except Exception as e:
self.fail(f"Expression '{expr}' should be valid but raised: {e}")
def test_range_and_list_syntax(self):
"""Test range and list syntax with abbreviations."""
test_cases = [
"0 9 * * MON-FRI", # Weekday range with abbreviations
"0 9 * JAN-MAR *", # Month range with abbreviations
"0 9 * * SUN,WED,FRI", # Weekday list with abbreviations
"0 9 1 JAN,JUN,DEC *", # Month list with abbreviations
]
for expr in test_cases:
with self.subTest(expr=expr):
try:
next_time = calculate_next_run_at(expr, "UTC", self.base_time)
assert next_time is not None
assert next_time > self.base_time
except Exception as e:
self.fail(f"Expression '{expr}' should be valid but raised: {e}")
def test_invalid_enhanced_syntax(self):
"""Test that invalid enhanced syntax is properly rejected."""
invalid_expressions = [
"0 12 * JANUARY *", # Full month name (not supported)
"0 12 * * MONDAY", # Full day name (not supported)
"0 12 32 JAN *", # Invalid day with valid month
"15 10 1 * 8", # Invalid day of week
"15 10 1 INVALID *", # Invalid month abbreviation
"15 10 1 * INVALID", # Invalid day abbreviation
"@invalid", # Invalid predefined expression
]
for expr in invalid_expressions:
with self.subTest(expr=expr):
with pytest.raises((CroniterBadCronError, ValueError)):
calculate_next_run_at(expr, "UTC", self.base_time)
def test_edge_cases_with_enhanced_syntax(self):
"""Test edge cases with enhanced syntax."""
test_cases = [
("0 0 29 FEB *", lambda dt: dt.month == 2 and dt.day == 29), # Feb 29 with month abbreviation
]
for expr, validator in test_cases:
with self.subTest(expr=expr):
try:
next_time = calculate_next_run_at(expr, "UTC", self.base_time)
if next_time: # Some combinations might not occur soon
assert validator(next_time), f"Validator failed for {expr}: {next_time}"
except (CroniterBadCronError, ValueError):
# Some edge cases might be valid but not have upcoming occurrences
pass
# Test complex expressions that have specific constraints
complex_expr = "59 23 31 DEC SAT" # December 31st at 23:59 on Saturday
try:
next_time = calculate_next_run_at(complex_expr, "UTC", self.base_time)
if next_time:
# The next occurrence might not be exactly Dec 31 if it's not a Saturday
# Just verify it's a valid result
assert next_time is not None
assert next_time.hour == 23
assert next_time.minute == 59
except Exception:
# Complex date constraints might not have near-future occurrences
pass
class TestTimezoneCompatibility(unittest.TestCase):
"""Test timezone compatibility between frontend and backend."""
def setUp(self):
"""Set up test environment."""
self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)
def test_timezone_consistency(self):
"""Test that calculations are consistent across different timezones."""
timezones = [
"UTC",
"America/New_York",
"Europe/London",
"Asia/Tokyo",
"Asia/Kolkata",
"Australia/Sydney",
]
expression = "0 12 * * *" # Daily at noon
for timezone in timezones:
with self.subTest(timezone=timezone):
next_time = calculate_next_run_at(expression, timezone, self.base_time)
assert next_time is not None
# Convert back to the target timezone to verify it's noon
tz = pytz.timezone(timezone)
local_time = next_time.astimezone(tz)
assert local_time.hour == 12
assert local_time.minute == 0
def test_dst_handling(self):
"""Test DST boundary handling."""
# Test around DST spring forward (March 2024)
dst_base = datetime(2024, 3, 8, 10, 0, 0, tzinfo=UTC)
expression = "0 2 * * *" # 2 AM daily (problematic during DST)
timezone = "America/New_York"
try:
next_time = calculate_next_run_at(expression, timezone, dst_base)
assert next_time is not None
# During DST spring forward, 2 AM becomes 3 AM - both are acceptable
tz = pytz.timezone(timezone)
local_time = next_time.astimezone(tz)
assert local_time.hour in [2, 3] # Either 2 AM or 3 AM is acceptable
except Exception as e:
self.fail(f"DST handling failed: {e}")
def test_half_hour_timezones(self):
"""Test timezones with half-hour offsets."""
timezones_with_offsets = [
("Asia/Kolkata", 17, 30), # UTC+5:30 -> 12:00 UTC = 17:30 IST
("Australia/Adelaide", 22, 30), # UTC+10:30 -> 12:00 UTC = 22:30 ACDT (summer time)
]
expression = "0 12 * * *" # Noon UTC
for timezone, expected_hour, expected_minute in timezones_with_offsets:
with self.subTest(timezone=timezone):
try:
next_time = calculate_next_run_at(expression, timezone, self.base_time)
assert next_time is not None
tz = pytz.timezone(timezone)
local_time = next_time.astimezone(tz)
assert local_time.hour == expected_hour
assert local_time.minute == expected_minute
except Exception:
# Some complex timezone calculations might vary
pass
def test_invalid_timezone_handling(self):
"""Test handling of invalid timezones."""
expression = "0 12 * * *"
invalid_timezone = "Invalid/Timezone"
with pytest.raises((ValueError, Exception)): # Should raise an exception
calculate_next_run_at(expression, invalid_timezone, self.base_time)
class TestFrontendBackendIntegration(unittest.TestCase):
"""Test integration patterns that mirror frontend usage."""
def setUp(self):
"""Set up test environment."""
self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)
def test_execution_time_calculator_pattern(self):
"""Test the pattern used by execution-time-calculator.ts."""
# This mirrors the exact usage from execution-time-calculator.ts:47
test_data = {
"cron_expression": "30 14 * * 1-5", # 2:30 PM weekdays
"timezone": "America/New_York",
}
# Get next 5 execution times (like the frontend does)
execution_times = []
current_base = self.base_time
for _ in range(5):
next_time = calculate_next_run_at(test_data["cron_expression"], test_data["timezone"], current_base)
assert next_time is not None
execution_times.append(next_time)
current_base = next_time + timedelta(seconds=1) # Move slightly forward
assert len(execution_times) == 5
# Validate each execution time
for exec_time in execution_times:
# Convert to local timezone
tz = pytz.timezone(test_data["timezone"])
local_time = exec_time.astimezone(tz)
# Should be weekdays (1-5)
assert local_time.weekday() in [0, 1, 2, 3, 4] # Mon-Fri in Python weekday
# Should be 2:30 PM in local time
assert local_time.hour == 14
assert local_time.minute == 30
assert local_time.second == 0
def test_schedule_service_integration(self):
"""Test integration with ScheduleService patterns."""
from core.workflow.nodes.trigger_schedule.entities import VisualConfig
from services.trigger.schedule_service import ScheduleService
# Test enhanced syntax through visual config conversion
visual_configs = [
# Test with month abbreviations
{
"frequency": "monthly",
"config": VisualConfig(time="9:00 AM", monthly_days=[1]),
"expected_cron": "0 9 1 * *",
},
# Test with weekday abbreviations
{
"frequency": "weekly",
"config": VisualConfig(time="2:30 PM", weekdays=["mon", "wed", "fri"]),
"expected_cron": "30 14 * * 1,3,5",
},
]
for test_case in visual_configs:
with self.subTest(frequency=test_case["frequency"]):
cron_expr = ScheduleService.visual_to_cron(test_case["frequency"], test_case["config"])
assert cron_expr == test_case["expected_cron"]
# Verify the generated cron expression is valid
next_time = calculate_next_run_at(cron_expr, "UTC", self.base_time)
assert next_time is not None
def test_error_handling_consistency(self):
"""Test that error handling matches frontend expectations."""
invalid_expressions = [
"60 10 1 * *", # Invalid minute
"15 25 1 * *", # Invalid hour
"15 10 32 * *", # Invalid day
"15 10 1 13 *", # Invalid month
"15 10 1", # Too few fields
"15 10 1 * * *", # 6 fields (not supported in frontend)
"0 15 10 1 * * *", # 7 fields (not supported in frontend)
"invalid expression", # Completely invalid
]
for expr in invalid_expressions:
with self.subTest(expr=repr(expr)):
with pytest.raises((CroniterBadCronError, ValueError, Exception)):
calculate_next_run_at(expr, "UTC", self.base_time)
# Note: Empty/whitespace expressions are not tested here as they are
# not expected in normal usage due to database constraints (nullable=False)
def test_performance_requirements(self):
"""Test that complex expressions parse within reasonable time."""
import time
complex_expressions = [
"*/5 9-17 * * 1-5", # Every 5 minutes, weekdays, business hours
"0 */2 1,15 * *", # Every 2 hours on 1st and 15th
"30 14 * * 1,3,5", # Mon, Wed, Fri at 14:30
"15,45 8-18 * * 1-5", # 15 and 45 minutes past hour, weekdays
"0 9 * JAN-MAR MON-FRI", # Enhanced syntax: Q1 weekdays at 9 AM
"0 12 ? * SUN", # Enhanced syntax: Sundays at noon with ?
]
start_time = time.time()
for expr in complex_expressions:
with self.subTest(expr=expr):
try:
next_time = calculate_next_run_at(expr, "UTC", self.base_time)
assert next_time is not None
except CroniterBadCronError:
# Some enhanced syntax might not be supported, that's OK
pass
end_time = time.time()
execution_time = (end_time - start_time) * 1000 # Convert to milliseconds
# Should complete within reasonable time (less than 150ms like frontend)
assert execution_time < 150, "Complex expressions should parse quickly"
if __name__ == "__main__":
# Import timedelta for the test
from datetime import timedelta
unittest.main()

View File

@@ -0,0 +1,68 @@
"""Unit tests for custom input types."""
import pytest
from libs.custom_inputs import time_duration
class TestTimeDuration:
"""Test time_duration input validator."""
def test_valid_days(self):
"""Test valid days format."""
result = time_duration("7d")
assert result == "7d"
def test_valid_hours(self):
"""Test valid hours format."""
result = time_duration("4h")
assert result == "4h"
def test_valid_minutes(self):
"""Test valid minutes format."""
result = time_duration("30m")
assert result == "30m"
def test_valid_seconds(self):
"""Test valid seconds format."""
result = time_duration("30s")
assert result == "30s"
def test_uppercase_conversion(self):
"""Test uppercase units are converted to lowercase."""
result = time_duration("7D")
assert result == "7d"
result = time_duration("4H")
assert result == "4h"
def test_invalid_format_no_unit(self):
"""Test invalid format without unit."""
with pytest.raises(ValueError, match="Invalid time duration format"):
time_duration("7")
def test_invalid_format_wrong_unit(self):
"""Test invalid format with wrong unit."""
with pytest.raises(ValueError, match="Invalid time duration format"):
time_duration("7days")
with pytest.raises(ValueError, match="Invalid time duration format"):
time_duration("7x")
def test_invalid_format_no_number(self):
"""Test invalid format without number."""
with pytest.raises(ValueError, match="Invalid time duration format"):
time_duration("d")
with pytest.raises(ValueError, match="Invalid time duration format"):
time_duration("abc")
def test_empty_string(self):
"""Test empty string."""
with pytest.raises(ValueError, match="Time duration cannot be empty"):
time_duration("")
def test_none(self):
"""Test None value."""
with pytest.raises(ValueError, match="Time duration cannot be empty"):
time_duration(None)

View File

@@ -0,0 +1,268 @@
import datetime
from unittest.mock import patch
import pytest
import pytz
from libs.datetime_utils import naive_utc_now, parse_time_range
def test_naive_utc_now(monkeypatch: pytest.MonkeyPatch):
tz_aware_utc_now = datetime.datetime.now(tz=datetime.UTC)
def _now_func(tz: datetime.timezone | None) -> datetime.datetime:
return tz_aware_utc_now.astimezone(tz)
monkeypatch.setattr("libs.datetime_utils._now_func", _now_func)
naive_datetime = naive_utc_now()
assert naive_datetime.tzinfo is None
assert naive_datetime.date() == tz_aware_utc_now.date()
naive_time = naive_datetime.time()
utc_time = tz_aware_utc_now.time()
assert naive_time == utc_time
class TestParseTimeRange:
"""Test cases for parse_time_range function."""
def test_parse_time_range_basic(self):
"""Test basic time range parsing."""
start, end = parse_time_range("2024-01-01 10:00", "2024-01-01 18:00", "UTC")
assert start is not None
assert end is not None
assert start < end
assert start.tzinfo == pytz.UTC
assert end.tzinfo == pytz.UTC
def test_parse_time_range_start_only(self):
"""Test parsing with only start time."""
start, end = parse_time_range("2024-01-01 10:00", None, "UTC")
assert start is not None
assert end is None
assert start.tzinfo == pytz.UTC
def test_parse_time_range_end_only(self):
"""Test parsing with only end time."""
start, end = parse_time_range(None, "2024-01-01 18:00", "UTC")
assert start is None
assert end is not None
assert end.tzinfo == pytz.UTC
def test_parse_time_range_both_none(self):
"""Test parsing with both times None."""
start, end = parse_time_range(None, None, "UTC")
assert start is None
assert end is None
def test_parse_time_range_different_timezones(self):
"""Test parsing with different timezones."""
# Test with US/Eastern timezone
start, end = parse_time_range("2024-01-01 10:00", "2024-01-01 18:00", "US/Eastern")
assert start is not None
assert end is not None
assert start.tzinfo == pytz.UTC
assert end.tzinfo == pytz.UTC
# Verify the times are correctly converted to UTC
assert start.hour == 15 # 10 AM EST = 3 PM UTC (in January)
assert end.hour == 23 # 6 PM EST = 11 PM UTC (in January)
def test_parse_time_range_invalid_start_format(self):
"""Test parsing with invalid start time format."""
with pytest.raises(ValueError, match="time data.*does not match format"):
parse_time_range("invalid-date", "2024-01-01 18:00", "UTC")
def test_parse_time_range_invalid_end_format(self):
"""Test parsing with invalid end time format."""
with pytest.raises(ValueError, match="time data.*does not match format"):
parse_time_range("2024-01-01 10:00", "invalid-date", "UTC")
def test_parse_time_range_invalid_timezone(self):
"""Test parsing with invalid timezone."""
with pytest.raises(pytz.exceptions.UnknownTimeZoneError):
parse_time_range("2024-01-01 10:00", "2024-01-01 18:00", "Invalid/Timezone")
def test_parse_time_range_start_after_end(self):
"""Test parsing with start time after end time."""
with pytest.raises(ValueError, match="start must be earlier than or equal to end"):
parse_time_range("2024-01-01 18:00", "2024-01-01 10:00", "UTC")
def test_parse_time_range_start_equals_end(self):
"""Test parsing with start time equal to end time."""
start, end = parse_time_range("2024-01-01 10:00", "2024-01-01 10:00", "UTC")
assert start is not None
assert end is not None
assert start == end
def test_parse_time_range_dst_ambiguous_time(self):
"""Test parsing during DST ambiguous time (fall back)."""
# This test simulates DST fall back where 2:30 AM occurs twice
with patch("pytz.timezone") as mock_timezone:
# Mock timezone that raises AmbiguousTimeError
mock_tz = mock_timezone.return_value
# Create a mock datetime object for the return value
mock_dt = datetime.datetime(2024, 1, 1, 10, 0, 0)
mock_utc_dt = mock_dt.replace(tzinfo=pytz.UTC)
# Create a proper mock for the localized datetime
from unittest.mock import MagicMock
mock_localized_dt = MagicMock()
mock_localized_dt.astimezone.return_value = mock_utc_dt
# Set up side effects: first call raises exception, second call succeeds
mock_tz.localize.side_effect = [
pytz.AmbiguousTimeError("Ambiguous time"), # First call for start
mock_localized_dt, # Second call for start (with is_dst=False)
pytz.AmbiguousTimeError("Ambiguous time"), # First call for end
mock_localized_dt, # Second call for end (with is_dst=False)
]
start, end = parse_time_range("2024-01-01 10:00", "2024-01-01 18:00", "US/Eastern")
# Should use is_dst=False for ambiguous times
assert mock_tz.localize.call_count == 4 # 2 calls per time (first fails, second succeeds)
assert start is not None
assert end is not None
def test_parse_time_range_dst_nonexistent_time(self):
"""Test parsing during DST nonexistent time (spring forward)."""
with patch("pytz.timezone") as mock_timezone:
# Mock timezone that raises NonExistentTimeError
mock_tz = mock_timezone.return_value
# Create a mock datetime object for the return value
mock_dt = datetime.datetime(2024, 1, 1, 10, 0, 0)
mock_utc_dt = mock_dt.replace(tzinfo=pytz.UTC)
# Create a proper mock for the localized datetime
from unittest.mock import MagicMock
mock_localized_dt = MagicMock()
mock_localized_dt.astimezone.return_value = mock_utc_dt
# Set up side effects: first call raises exception, second call succeeds
mock_tz.localize.side_effect = [
pytz.NonExistentTimeError("Non-existent time"), # First call for start
mock_localized_dt, # Second call for start (with adjusted time)
pytz.NonExistentTimeError("Non-existent time"), # First call for end
mock_localized_dt, # Second call for end (with adjusted time)
]
start, end = parse_time_range("2024-01-01 10:00", "2024-01-01 18:00", "US/Eastern")
# Should adjust time forward by 1 hour for nonexistent times
assert mock_tz.localize.call_count == 4 # 2 calls per time (first fails, second succeeds)
assert start is not None
assert end is not None
def test_parse_time_range_edge_cases(self):
"""Test edge cases for time parsing."""
# Test with midnight times
start, end = parse_time_range("2024-01-01 00:00", "2024-01-01 23:59", "UTC")
assert start is not None
assert end is not None
assert start.hour == 0
assert start.minute == 0
assert end.hour == 23
assert end.minute == 59
def test_parse_time_range_different_dates(self):
"""Test parsing with different dates."""
start, end = parse_time_range("2024-01-01 10:00", "2024-01-02 10:00", "UTC")
assert start is not None
assert end is not None
assert start.date() != end.date()
assert (end - start).days == 1
def test_parse_time_range_seconds_handling(self):
"""Test that seconds are properly set to 0."""
start, end = parse_time_range("2024-01-01 10:30", "2024-01-01 18:45", "UTC")
assert start is not None
assert end is not None
assert start.second == 0
assert end.second == 0
def test_parse_time_range_timezone_conversion_accuracy(self):
"""Test accurate timezone conversion."""
# Test with a known timezone conversion
start, end = parse_time_range("2024-01-01 12:00", "2024-01-01 12:00", "Asia/Tokyo")
assert start is not None
assert end is not None
assert start.tzinfo == pytz.UTC
assert end.tzinfo == pytz.UTC
# Tokyo is UTC+9, so 12:00 JST = 03:00 UTC
assert start.hour == 3
assert end.hour == 3
def test_parse_time_range_summer_time(self):
"""Test parsing during summer time (DST)."""
# Test with US/Eastern during summer (EDT = UTC-4)
start, end = parse_time_range("2024-07-01 12:00", "2024-07-01 12:00", "US/Eastern")
assert start is not None
assert end is not None
assert start.tzinfo == pytz.UTC
assert end.tzinfo == pytz.UTC
# 12:00 EDT = 16:00 UTC
assert start.hour == 16
assert end.hour == 16
def test_parse_time_range_winter_time(self):
"""Test parsing during winter time (standard time)."""
# Test with US/Eastern during winter (EST = UTC-5)
start, end = parse_time_range("2024-01-01 12:00", "2024-01-01 12:00", "US/Eastern")
assert start is not None
assert end is not None
assert start.tzinfo == pytz.UTC
assert end.tzinfo == pytz.UTC
# 12:00 EST = 17:00 UTC
assert start.hour == 17
assert end.hour == 17
def test_parse_time_range_empty_strings(self):
"""Test parsing with empty strings."""
# Empty strings are treated as None, so they should not raise errors
start, end = parse_time_range("", "2024-01-01 18:00", "UTC")
assert start is None
assert end is not None
start, end = parse_time_range("2024-01-01 10:00", "", "UTC")
assert start is not None
assert end is None
def test_parse_time_range_malformed_datetime(self):
"""Test parsing with malformed datetime strings."""
with pytest.raises(ValueError, match="time data.*does not match format"):
parse_time_range("2024-13-01 10:00", "2024-01-01 18:00", "UTC")
with pytest.raises(ValueError, match="time data.*does not match format"):
parse_time_range("2024-01-01 10:00", "2024-01-32 18:00", "UTC")
def test_parse_time_range_very_long_time_range(self):
"""Test parsing with very long time range."""
start, end = parse_time_range("2020-01-01 00:00", "2030-12-31 23:59", "UTC")
assert start is not None
assert end is not None
assert start < end
assert (end - start).days > 3000 # More than 8 years
def test_parse_time_range_negative_timezone(self):
"""Test parsing with negative timezone offset."""
start, end = parse_time_range("2024-01-01 12:00", "2024-01-01 12:00", "America/New_York")
assert start is not None
assert end is not None
assert start.tzinfo == pytz.UTC
assert end.tzinfo == pytz.UTC

View File

@@ -0,0 +1,21 @@
import pytest
from libs.helper import email
def test_email_with_valid_email():
assert email("test@example.com") == "test@example.com"
assert email("TEST12345@example.com") == "TEST12345@example.com"
assert email("test+test@example.com") == "test+test@example.com"
assert email("!#$%&'*+-/=?^_{|}~`@example.com") == "!#$%&'*+-/=?^_{|}~`@example.com"
def test_email_with_invalid_email():
with pytest.raises(ValueError, match="invalid_email is not a valid email."):
email("invalid_email")
with pytest.raises(ValueError, match="@example.com is not a valid email."):
email("@example.com")
with pytest.raises(ValueError, match="()@example.com is not a valid email."):
email("()@example.com")

View File

@@ -0,0 +1,576 @@
"""
Unit tests for EmailI18nService
Tests the email internationalization service with mocked dependencies
following Domain-Driven Design principles.
"""
from typing import Any
from unittest.mock import MagicMock
import pytest
from libs.email_i18n import (
EmailI18nConfig,
EmailI18nService,
EmailLanguage,
EmailTemplate,
EmailType,
FlaskEmailRenderer,
FlaskMailSender,
create_default_email_config,
get_email_i18n_service,
)
from services.feature_service import BrandingModel
class MockEmailRenderer:
"""Mock implementation of EmailRenderer protocol"""
def __init__(self):
self.rendered_templates: list[tuple[str, dict[str, Any]]] = []
def render_template(self, template_path: str, **context: Any) -> str:
"""Mock render_template that returns a formatted string"""
self.rendered_templates.append((template_path, context))
return f"<html>Rendered {template_path} with {context}</html>"
class MockBrandingService:
"""Mock implementation of BrandingService protocol"""
def __init__(self, enabled: bool = False, application_title: str = "Dify"):
self.enabled = enabled
self.application_title = application_title
def get_branding_config(self) -> BrandingModel:
"""Return mock branding configuration"""
branding_model = MagicMock(spec=BrandingModel)
branding_model.enabled = self.enabled
branding_model.application_title = self.application_title
return branding_model
class MockEmailSender:
"""Mock implementation of EmailSender protocol"""
def __init__(self):
self.sent_emails: list[dict[str, str]] = []
def send_email(self, to: str, subject: str, html_content: str):
"""Mock send_email that records sent emails"""
self.sent_emails.append(
{
"to": to,
"subject": subject,
"html_content": html_content,
}
)
class TestEmailI18nService:
"""Test cases for EmailI18nService"""
@pytest.fixture
def email_config(self) -> EmailI18nConfig:
"""Create test email configuration"""
return EmailI18nConfig(
templates={
EmailType.RESET_PASSWORD: {
EmailLanguage.EN_US: EmailTemplate(
subject="Reset Your {application_title} Password",
template_path="reset_password_en.html",
branded_template_path="branded/reset_password_en.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="重置您的 {application_title} 密码",
template_path="reset_password_zh.html",
branded_template_path="branded/reset_password_zh.html",
),
},
EmailType.INVITE_MEMBER: {
EmailLanguage.EN_US: EmailTemplate(
subject="Join {application_title} Workspace",
template_path="invite_member_en.html",
branded_template_path="branded/invite_member_en.html",
),
},
}
)
@pytest.fixture
def mock_renderer(self) -> MockEmailRenderer:
"""Create mock email renderer"""
return MockEmailRenderer()
@pytest.fixture
def mock_branding_service(self) -> MockBrandingService:
"""Create mock branding service"""
return MockBrandingService()
@pytest.fixture
def mock_sender(self) -> MockEmailSender:
"""Create mock email sender"""
return MockEmailSender()
@pytest.fixture
def email_service(
self,
email_config: EmailI18nConfig,
mock_renderer: MockEmailRenderer,
mock_branding_service: MockBrandingService,
mock_sender: MockEmailSender,
) -> EmailI18nService:
"""Create EmailI18nService with mocked dependencies"""
return EmailI18nService(
config=email_config,
renderer=mock_renderer,
branding_service=mock_branding_service,
sender=mock_sender,
)
def test_send_email_with_english_language(
self,
email_service: EmailI18nService,
mock_renderer: MockEmailRenderer,
mock_sender: MockEmailSender,
):
"""Test sending email with English language"""
email_service.send_email(
email_type=EmailType.RESET_PASSWORD,
language_code="en-US",
to="test@example.com",
template_context={"reset_link": "https://example.com/reset"},
)
# Verify renderer was called with correct template
assert len(mock_renderer.rendered_templates) == 1
template_path, context = mock_renderer.rendered_templates[0]
assert template_path == "reset_password_en.html"
assert context["reset_link"] == "https://example.com/reset"
assert context["branding_enabled"] is False
assert context["application_title"] == "Dify"
# Verify email was sent
assert len(mock_sender.sent_emails) == 1
sent_email = mock_sender.sent_emails[0]
assert sent_email["to"] == "test@example.com"
assert sent_email["subject"] == "Reset Your Dify Password"
assert "reset_password_en.html" in sent_email["html_content"]
def test_send_email_with_chinese_language(
self,
email_service: EmailI18nService,
mock_sender: MockEmailSender,
):
"""Test sending email with Chinese language"""
email_service.send_email(
email_type=EmailType.RESET_PASSWORD,
language_code="zh-Hans",
to="test@example.com",
template_context={"reset_link": "https://example.com/reset"},
)
# Verify email was sent with Chinese subject
assert len(mock_sender.sent_emails) == 1
sent_email = mock_sender.sent_emails[0]
assert sent_email["subject"] == "重置您的 Dify 密码"
def test_send_email_with_branding_enabled(
self,
email_config: EmailI18nConfig,
mock_renderer: MockEmailRenderer,
mock_sender: MockEmailSender,
):
"""Test sending email with branding enabled"""
# Create branding service with branding enabled
branding_service = MockBrandingService(enabled=True, application_title="MyApp")
email_service = EmailI18nService(
config=email_config,
renderer=mock_renderer,
branding_service=branding_service,
sender=mock_sender,
)
email_service.send_email(
email_type=EmailType.RESET_PASSWORD,
language_code="en-US",
to="test@example.com",
)
# Verify branded template was used
assert len(mock_renderer.rendered_templates) == 1
template_path, context = mock_renderer.rendered_templates[0]
assert template_path == "branded/reset_password_en.html"
assert context["branding_enabled"] is True
assert context["application_title"] == "MyApp"
# Verify subject includes custom application title
assert len(mock_sender.sent_emails) == 1
sent_email = mock_sender.sent_emails[0]
assert sent_email["subject"] == "Reset Your MyApp Password"
def test_send_email_with_language_fallback(
self,
email_service: EmailI18nService,
mock_sender: MockEmailSender,
):
"""Test language fallback to English when requested language not available"""
# Request invite member in Chinese (not configured)
email_service.send_email(
email_type=EmailType.INVITE_MEMBER,
language_code="zh-Hans",
to="test@example.com",
)
# Should fall back to English
assert len(mock_sender.sent_emails) == 1
sent_email = mock_sender.sent_emails[0]
assert sent_email["subject"] == "Join Dify Workspace"
def test_send_email_with_unknown_language_code(
self,
email_service: EmailI18nService,
mock_sender: MockEmailSender,
):
"""Test unknown language code falls back to English"""
email_service.send_email(
email_type=EmailType.RESET_PASSWORD,
language_code="fr-FR", # French not configured
to="test@example.com",
)
# Should use English
assert len(mock_sender.sent_emails) == 1
sent_email = mock_sender.sent_emails[0]
assert sent_email["subject"] == "Reset Your Dify Password"
def test_subject_format_keyerror_fallback_path(
self,
mock_renderer: MockEmailRenderer,
mock_sender: MockEmailSender,
):
"""Trigger subject KeyError and cover except branch."""
# Config with subject that references an unknown key (no {application_title} to avoid second format)
config = EmailI18nConfig(
templates={
EmailType.INVITE_MEMBER: {
EmailLanguage.EN_US: EmailTemplate(
subject="Invite: {unknown_placeholder}",
template_path="invite_member_en.html",
branded_template_path="branded/invite_member_en.html",
),
}
}
)
branding_service = MockBrandingService(enabled=False)
service = EmailI18nService(
config=config,
renderer=mock_renderer,
branding_service=branding_service,
sender=mock_sender,
)
# Will raise KeyError on subject.format(**full_context), then hit except branch and skip fallback
service.send_email(
email_type=EmailType.INVITE_MEMBER,
language_code="en-US",
to="test@example.com",
)
assert len(mock_sender.sent_emails) == 1
# Subject is left unformatted due to KeyError fallback path without application_title
assert mock_sender.sent_emails[0]["subject"] == "Invite: {unknown_placeholder}"
def test_send_change_email_old_phase(
self,
email_config: EmailI18nConfig,
mock_renderer: MockEmailRenderer,
mock_sender: MockEmailSender,
mock_branding_service: MockBrandingService,
):
"""Test sending change email for old email verification"""
# Add change email templates to config
email_config.templates[EmailType.CHANGE_EMAIL_OLD] = {
EmailLanguage.EN_US: EmailTemplate(
subject="Verify your current email",
template_path="change_email_old_en.html",
branded_template_path="branded/change_email_old_en.html",
),
}
email_service = EmailI18nService(
config=email_config,
renderer=mock_renderer,
branding_service=mock_branding_service,
sender=mock_sender,
)
email_service.send_change_email(
language_code="en-US",
to="old@example.com",
code="123456",
phase="old_email",
)
# Verify correct template and context
assert len(mock_renderer.rendered_templates) == 1
template_path, context = mock_renderer.rendered_templates[0]
assert template_path == "change_email_old_en.html"
assert context["to"] == "old@example.com"
assert context["code"] == "123456"
def test_send_change_email_new_phase(
self,
email_config: EmailI18nConfig,
mock_renderer: MockEmailRenderer,
mock_sender: MockEmailSender,
mock_branding_service: MockBrandingService,
):
"""Test sending change email for new email verification"""
# Add change email templates to config
email_config.templates[EmailType.CHANGE_EMAIL_NEW] = {
EmailLanguage.EN_US: EmailTemplate(
subject="Verify your new email",
template_path="change_email_new_en.html",
branded_template_path="branded/change_email_new_en.html",
),
}
email_service = EmailI18nService(
config=email_config,
renderer=mock_renderer,
branding_service=mock_branding_service,
sender=mock_sender,
)
email_service.send_change_email(
language_code="en-US",
to="new@example.com",
code="654321",
phase="new_email",
)
# Verify correct template and context
assert len(mock_renderer.rendered_templates) == 1
template_path, context = mock_renderer.rendered_templates[0]
assert template_path == "change_email_new_en.html"
assert context["to"] == "new@example.com"
assert context["code"] == "654321"
def test_send_change_email_invalid_phase(
self,
email_service: EmailI18nService,
):
"""Test sending change email with invalid phase raises error"""
with pytest.raises(ValueError, match="Invalid phase: invalid_phase"):
email_service.send_change_email(
language_code="en-US",
to="test@example.com",
code="123456",
phase="invalid_phase",
)
def test_send_raw_email_single_recipient(
self,
email_service: EmailI18nService,
mock_sender: MockEmailSender,
):
"""Test sending raw email to single recipient"""
email_service.send_raw_email(
to="test@example.com",
subject="Test Subject",
html_content="<html>Test Content</html>",
)
assert len(mock_sender.sent_emails) == 1
sent_email = mock_sender.sent_emails[0]
assert sent_email["to"] == "test@example.com"
assert sent_email["subject"] == "Test Subject"
assert sent_email["html_content"] == "<html>Test Content</html>"
def test_send_raw_email_multiple_recipients(
self,
email_service: EmailI18nService,
mock_sender: MockEmailSender,
):
"""Test sending raw email to multiple recipients"""
recipients = ["user1@example.com", "user2@example.com", "user3@example.com"]
email_service.send_raw_email(
to=recipients,
subject="Test Subject",
html_content="<html>Test Content</html>",
)
# Should send individual emails to each recipient
assert len(mock_sender.sent_emails) == 3
for i, recipient in enumerate(recipients):
sent_email = mock_sender.sent_emails[i]
assert sent_email["to"] == recipient
assert sent_email["subject"] == "Test Subject"
assert sent_email["html_content"] == "<html>Test Content</html>"
def test_get_template_missing_email_type(
self,
email_config: EmailI18nConfig,
):
"""Test getting template for missing email type raises error"""
with pytest.raises(ValueError, match="No templates configured for email type"):
email_config.get_template(EmailType.EMAIL_CODE_LOGIN, EmailLanguage.EN_US)
def test_get_template_missing_language_and_english(
self,
email_config: EmailI18nConfig,
):
"""Test error when neither requested language nor English fallback exists"""
# Add template without English fallback
email_config.templates[EmailType.EMAIL_CODE_LOGIN] = {
EmailLanguage.ZH_HANS: EmailTemplate(
subject="Test",
template_path="test.html",
branded_template_path="branded/test.html",
),
}
with pytest.raises(ValueError, match="No template found for"):
# Request a language that doesn't exist and no English fallback
email_config.get_template(EmailType.EMAIL_CODE_LOGIN, EmailLanguage.EN_US)
def test_subject_templating_with_variables(
self,
email_config: EmailI18nConfig,
mock_renderer: MockEmailRenderer,
mock_sender: MockEmailSender,
mock_branding_service: MockBrandingService,
):
"""Test subject templating with custom variables"""
# Add template with variable in subject
email_config.templates[EmailType.OWNER_TRANSFER_NEW_NOTIFY] = {
EmailLanguage.EN_US: EmailTemplate(
subject="You are now the owner of {WorkspaceName}",
template_path="owner_transfer_en.html",
branded_template_path="branded/owner_transfer_en.html",
),
}
email_service = EmailI18nService(
config=email_config,
renderer=mock_renderer,
branding_service=mock_branding_service,
sender=mock_sender,
)
email_service.send_email(
email_type=EmailType.OWNER_TRANSFER_NEW_NOTIFY,
language_code="en-US",
to="test@example.com",
template_context={"WorkspaceName": "My Workspace"},
)
# Verify subject was templated correctly
assert len(mock_sender.sent_emails) == 1
sent_email = mock_sender.sent_emails[0]
assert sent_email["subject"] == "You are now the owner of My Workspace"
def test_email_language_from_language_code(self):
"""Test EmailLanguage.from_language_code method"""
assert EmailLanguage.from_language_code("zh-Hans") == EmailLanguage.ZH_HANS
assert EmailLanguage.from_language_code("en-US") == EmailLanguage.EN_US
assert EmailLanguage.from_language_code("fr-FR") == EmailLanguage.EN_US # Fallback
assert EmailLanguage.from_language_code("unknown") == EmailLanguage.EN_US # Fallback
class TestEmailI18nIntegration:
"""Integration tests for email i18n components"""
def test_create_default_email_config(self):
"""Test creating default email configuration"""
config = create_default_email_config()
# Verify key email types have at least English template
expected_types = [
EmailType.RESET_PASSWORD,
EmailType.INVITE_MEMBER,
EmailType.EMAIL_CODE_LOGIN,
EmailType.CHANGE_EMAIL_OLD,
EmailType.CHANGE_EMAIL_NEW,
EmailType.OWNER_TRANSFER_CONFIRM,
EmailType.OWNER_TRANSFER_OLD_NOTIFY,
EmailType.OWNER_TRANSFER_NEW_NOTIFY,
EmailType.ACCOUNT_DELETION_SUCCESS,
EmailType.ACCOUNT_DELETION_VERIFICATION,
EmailType.QUEUE_MONITOR_ALERT,
EmailType.DOCUMENT_CLEAN_NOTIFY,
]
for email_type in expected_types:
assert email_type in config.templates
assert EmailLanguage.EN_US in config.templates[email_type]
# Verify some have Chinese translations
assert EmailLanguage.ZH_HANS in config.templates[EmailType.RESET_PASSWORD]
assert EmailLanguage.ZH_HANS in config.templates[EmailType.INVITE_MEMBER]
def test_get_email_i18n_service(self):
"""Test getting global email i18n service instance"""
service1 = get_email_i18n_service()
service2 = get_email_i18n_service()
# Should return the same instance
assert service1 is service2
def test_flask_email_renderer(self):
"""Test FlaskEmailRenderer implementation"""
renderer = FlaskEmailRenderer()
# Should raise TemplateNotFound when template doesn't exist
from jinja2.exceptions import TemplateNotFound
with pytest.raises(TemplateNotFound):
renderer.render_template("test.html", foo="bar")
def test_flask_mail_sender_not_initialized(self):
"""Test FlaskMailSender when mail is not initialized"""
sender = FlaskMailSender()
# Mock mail.is_inited() to return False
import libs.email_i18n
original_mail = libs.email_i18n.mail
mock_mail = MagicMock()
mock_mail.is_inited.return_value = False
libs.email_i18n.mail = mock_mail
try:
# Should not send email when mail is not initialized
sender.send_email("test@example.com", "Subject", "<html>Content</html>")
mock_mail.send.assert_not_called()
finally:
# Restore original mail
libs.email_i18n.mail = original_mail
def test_flask_mail_sender_initialized(self):
"""Test FlaskMailSender when mail is initialized"""
sender = FlaskMailSender()
# Mock mail.is_inited() to return True
import libs.email_i18n
original_mail = libs.email_i18n.mail
mock_mail = MagicMock()
mock_mail.is_inited.return_value = True
libs.email_i18n.mail = mock_mail
try:
# Should send email when mail is initialized
sender.send_email("test@example.com", "Subject", "<html>Content</html>")
mock_mail.send.assert_called_once_with(
to="test@example.com",
subject="Subject",
html="<html>Content</html>",
)
finally:
# Restore original mail
libs.email_i18n.mail = original_mail

View File

@@ -0,0 +1,187 @@
from flask import Blueprint, Flask
from flask_restx import Resource
from werkzeug.exceptions import BadRequest, Unauthorized
from constants import COOKIE_NAME_ACCESS_TOKEN, COOKIE_NAME_CSRF_TOKEN, COOKIE_NAME_REFRESH_TOKEN
from core.errors.error import AppInvokeQuotaExceededError
from libs.exception import BaseHTTPException
from libs.external_api import ExternalApi
def _create_api_app():
app = Flask(__name__)
bp = Blueprint("t", __name__)
api = ExternalApi(bp)
@api.route("/bad-request")
class Bad(Resource):
def get(self):
raise BadRequest("invalid input")
@api.route("/unauth")
class Unauth(Resource):
def get(self):
raise Unauthorized("auth required")
@api.route("/value-error")
class ValErr(Resource):
def get(self):
raise ValueError("boom")
@api.route("/quota")
class Quota(Resource):
def get(self):
raise AppInvokeQuotaExceededError("quota exceeded")
@api.route("/general")
class Gen(Resource):
def get(self):
raise RuntimeError("oops")
# Note: We avoid altering default_mediatype to keep normal error paths
# Special 400 message rewrite
@api.route("/json-empty")
class JsonEmpty(Resource):
def get(self):
e = BadRequest()
# Force the specific message the handler rewrites
e.description = "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)"
raise e
# 400 mapping payload path
@api.route("/param-errors")
class ParamErrors(Resource):
def get(self):
e = BadRequest()
# Coerce a mapping description to trigger param error shaping
e.description = {"field": "is required"}
raise e
app.register_blueprint(bp, url_prefix="/api")
return app
def test_external_api_error_handlers_basic_paths():
app = _create_api_app()
client = app.test_client()
# 400
res = client.get("/api/bad-request")
assert res.status_code == 400
data = res.get_json()
assert data["code"] == "bad_request"
assert data["status"] == 400
# 401
res = client.get("/api/unauth")
assert res.status_code == 401
assert "WWW-Authenticate" in res.headers
# 400 ValueError
res = client.get("/api/value-error")
assert res.status_code == 400
assert res.get_json()["code"] == "invalid_param"
# 500 general
res = client.get("/api/general")
assert res.status_code == 500
assert res.get_json()["status"] == 500
def test_external_api_json_message_and_bad_request_rewrite():
app = _create_api_app()
client = app.test_client()
# JSON empty special rewrite
res = client.get("/api/json-empty")
assert res.status_code == 400
assert res.get_json()["message"] == "Invalid JSON payload received or JSON payload is empty."
def test_external_api_param_mapping_and_quota_and_exc_info_none():
# Force exc_info() to return (None,None,None) only during request
import libs.external_api as ext
orig_exc_info = ext.sys.exc_info
try:
ext.sys.exc_info = lambda: (None, None, None)
app = _create_api_app()
client = app.test_client()
# Param errors mapping payload path
res = client.get("/api/param-errors")
assert res.status_code == 400
data = res.get_json()
assert data["code"] == "invalid_param"
assert data["params"] == "field"
# Quota path — depending on Flask-RESTX internals it may be handled
res = client.get("/api/quota")
assert res.status_code in (400, 429)
finally:
ext.sys.exc_info = orig_exc_info # type: ignore[assignment]
def test_unauthorized_and_force_logout_clears_cookies():
"""Test that UnauthorizedAndForceLogout error clears auth cookies"""
class UnauthorizedAndForceLogout(BaseHTTPException):
error_code = "unauthorized_and_force_logout"
description = "Unauthorized and force logout."
code = 401
app = Flask(__name__)
bp = Blueprint("test", __name__)
api = ExternalApi(bp)
@api.route("/force-logout")
class ForceLogout(Resource): # type: ignore
def get(self): # type: ignore
raise UnauthorizedAndForceLogout()
app.register_blueprint(bp, url_prefix="/api")
client = app.test_client()
# Set cookies first
client.set_cookie(COOKIE_NAME_ACCESS_TOKEN, "test_access_token")
client.set_cookie(COOKIE_NAME_CSRF_TOKEN, "test_csrf_token")
client.set_cookie(COOKIE_NAME_REFRESH_TOKEN, "test_refresh_token")
# Make request that should trigger cookie clearing
res = client.get("/api/force-logout")
# Verify response
assert res.status_code == 401
data = res.get_json()
assert data["code"] == "unauthorized_and_force_logout"
assert data["status"] == 401
assert "WWW-Authenticate" in res.headers
# Verify Set-Cookie headers are present to clear cookies
set_cookie_headers = res.headers.getlist("Set-Cookie")
assert len(set_cookie_headers) == 3, f"Expected 3 Set-Cookie headers, got {len(set_cookie_headers)}"
# Verify each cookie is being cleared (empty value and expired)
cookie_names_found = set()
for cookie_header in set_cookie_headers:
# Check for cookie names
if COOKIE_NAME_ACCESS_TOKEN in cookie_header:
cookie_names_found.add(COOKIE_NAME_ACCESS_TOKEN)
assert '""' in cookie_header or "=" in cookie_header # Empty value
assert "Expires=Thu, 01 Jan 1970" in cookie_header # Expired
elif COOKIE_NAME_CSRF_TOKEN in cookie_header:
cookie_names_found.add(COOKIE_NAME_CSRF_TOKEN)
assert '""' in cookie_header or "=" in cookie_header
assert "Expires=Thu, 01 Jan 1970" in cookie_header
elif COOKIE_NAME_REFRESH_TOKEN in cookie_header:
cookie_names_found.add(COOKIE_NAME_REFRESH_TOKEN)
assert '""' in cookie_header or "=" in cookie_header
assert "Expires=Thu, 01 Jan 1970" in cookie_header
# Verify all three cookies are present
assert len(cookie_names_found) == 3
assert COOKIE_NAME_ACCESS_TOKEN in cookie_names_found
assert COOKIE_NAME_CSRF_TOKEN in cookie_names_found
assert COOKIE_NAME_REFRESH_TOKEN in cookie_names_found

View File

@@ -0,0 +1,55 @@
from pathlib import Path
import pytest
from libs.file_utils import search_file_upwards
def test_search_file_upwards_found_in_parent(tmp_path: Path):
base = tmp_path / "a" / "b" / "c"
base.mkdir(parents=True)
target = tmp_path / "a" / "target.txt"
target.write_text("ok", encoding="utf-8")
found = search_file_upwards(base, "target.txt", max_search_parent_depth=5)
assert found == target
def test_search_file_upwards_found_in_current(tmp_path: Path):
base = tmp_path / "x"
base.mkdir()
target = base / "here.txt"
target.write_text("x", encoding="utf-8")
found = search_file_upwards(base, "here.txt", max_search_parent_depth=1)
assert found == target
def test_search_file_upwards_not_found_raises(tmp_path: Path):
base = tmp_path / "m" / "n"
base.mkdir(parents=True)
with pytest.raises(ValueError) as exc:
search_file_upwards(base, "missing.txt", max_search_parent_depth=3)
# error message should contain file name and base path
msg = str(exc.value)
assert "missing.txt" in msg
assert str(base) in msg
def test_search_file_upwards_root_breaks_and_raises():
# Using filesystem root triggers the 'break' branch (parent == current)
with pytest.raises(ValueError):
search_file_upwards(Path("/"), "__definitely_not_exists__.txt", max_search_parent_depth=1)
def test_search_file_upwards_depth_limit_raises(tmp_path: Path):
base = tmp_path / "a" / "b" / "c"
base.mkdir(parents=True)
target = tmp_path / "a" / "target.txt"
target.write_text("ok", encoding="utf-8")
# The file is 2 levels up from `c` (in `a`), but search depth is only 2.
# The search path is `c` (depth 1) -> `b` (depth 2). The file is in `a` (would need depth 3).
# So, this should not find the file and should raise an error.
with pytest.raises(ValueError):
search_file_upwards(base, "target.txt", max_search_parent_depth=2)

View File

@@ -0,0 +1,123 @@
import contextvars
import threading
import pytest
from flask import Flask
from flask_login import LoginManager, UserMixin, current_user, login_user
from libs.flask_utils import preserve_flask_contexts
class User(UserMixin):
"""Simple User class for testing."""
def __init__(self, id: str):
self.id = id
def get_id(self) -> str:
return self.id
@pytest.fixture
def login_app(app: Flask) -> Flask:
"""Set up a Flask app with flask-login."""
# Set a secret key for the app
app.config["SECRET_KEY"] = "test-secret-key"
login_manager = LoginManager()
login_manager.init_app(app)
@login_manager.user_loader
def load_user(user_id: str) -> User | None:
if user_id == "test_user":
return User("test_user")
return None
return app
@pytest.fixture
def test_user() -> User:
"""Create a test user."""
return User("test_user")
def test_current_user_not_accessible_across_threads(login_app: Flask, test_user: User):
"""
Test that current_user is not accessible in a different thread without preserve_flask_contexts.
This test demonstrates that without the preserve_flask_contexts, we cannot access
current_user in a different thread, even with app_context.
"""
# Log in the user in the main thread
with login_app.test_request_context():
login_user(test_user)
assert current_user.is_authenticated
assert current_user.id == "test_user"
# Store the result of the thread execution
result = {"user_accessible": True, "error": None}
# Define a function to run in a separate thread
def check_user_in_thread():
try:
# Try to access current_user in a different thread with app_context
with login_app.app_context():
# This should fail because current_user is not accessible across threads
# without preserve_flask_contexts
result["user_accessible"] = current_user.is_authenticated
except Exception as e:
result["error"] = str(e)
# Run the function in a separate thread
thread = threading.Thread(target=check_user_in_thread)
thread.start()
thread.join()
# Verify that we got an error or current_user is not authenticated
assert result["error"] is not None or (result["user_accessible"] is not None and not result["user_accessible"])
def test_current_user_accessible_with_preserve_flask_contexts(login_app: Flask, test_user: User):
"""
Test that current_user is accessible in a different thread with preserve_flask_contexts.
This test demonstrates that with the preserve_flask_contexts, we can access
current_user in a different thread.
"""
# Log in the user in the main thread
with login_app.test_request_context():
login_user(test_user)
assert current_user.is_authenticated
assert current_user.id == "test_user"
# Save the context variables
context_vars = contextvars.copy_context()
# Store the result of the thread execution
result = {"user_accessible": False, "user_id": None, "error": None}
# Define a function to run in a separate thread
def check_user_in_thread_with_manager():
try:
# Use preserve_flask_contexts to access current_user in a different thread
with preserve_flask_contexts(login_app, context_vars):
from flask_login import current_user
if current_user:
result["user_accessible"] = True
result["user_id"] = current_user.id
else:
result["user_accessible"] = False
except Exception as e:
result["error"] = str(e)
# Run the function in a separate thread
thread = threading.Thread(target=check_user_in_thread_with_manager)
thread.start()
thread.join()
# Verify that current_user is accessible and has the correct ID
assert result["error"] is None
assert result["user_accessible"] is True
assert result["user_id"] == "test_user"

View File

@@ -0,0 +1,65 @@
import pytest
from libs.helper import extract_tenant_id
from models.account import Account
from models.model import EndUser
class TestExtractTenantId:
"""Test cases for the extract_tenant_id utility function."""
def test_extract_tenant_id_from_account_with_tenant(self):
"""Test extracting tenant_id from Account with current_tenant_id."""
# Create a mock Account object
account = Account(name="test", email="test@example.com")
# Mock the current_tenant_id property
account._current_tenant = type("MockTenant", (), {"id": "account-tenant-123"})()
tenant_id = extract_tenant_id(account)
assert tenant_id == "account-tenant-123"
def test_extract_tenant_id_from_account_without_tenant(self):
"""Test extracting tenant_id from Account without current_tenant_id."""
# Create a mock Account object
account = Account(name="test", email="test@example.com")
account._current_tenant = None
tenant_id = extract_tenant_id(account)
assert tenant_id is None
def test_extract_tenant_id_from_enduser_with_tenant(self):
"""Test extracting tenant_id from EndUser with tenant_id."""
# Create a mock EndUser object
end_user = EndUser()
end_user.tenant_id = "enduser-tenant-456"
tenant_id = extract_tenant_id(end_user)
assert tenant_id == "enduser-tenant-456"
def test_extract_tenant_id_from_enduser_without_tenant(self):
"""Test extracting tenant_id from EndUser without tenant_id."""
# Create a mock EndUser object
end_user = EndUser()
end_user.tenant_id = None
tenant_id = extract_tenant_id(end_user)
assert tenant_id is None
def test_extract_tenant_id_with_invalid_user_type(self):
"""Test extracting tenant_id with invalid user type raises ValueError."""
invalid_user = "not_a_user_object"
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
extract_tenant_id(invalid_user)
def test_extract_tenant_id_with_none_user(self):
"""Test extracting tenant_id with None user raises ValueError."""
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
extract_tenant_id(None)
def test_extract_tenant_id_with_dict_user(self):
"""Test extracting tenant_id with dict user raises ValueError."""
dict_user = {"id": "123", "tenant_id": "456"}
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
extract_tenant_id(dict_user)

View File

@@ -0,0 +1,109 @@
import pytest
from core.llm_generator.output_parser.errors import OutputParserError
from libs.json_in_md_parser import (
parse_and_check_json_markdown,
parse_json_markdown,
)
def test_parse_json_markdown_triple_backticks_json():
src = """
```json
{"a": 1, "b": "x"}
```
"""
assert parse_json_markdown(src) == {"a": 1, "b": "x"}
def test_parse_json_markdown_triple_backticks_generic():
src = """
```
{"k": [1, 2, 3]}
```
"""
assert parse_json_markdown(src) == {"k": [1, 2, 3]}
def test_parse_json_markdown_single_backticks():
src = '`{"x": true}`'
assert parse_json_markdown(src) == {"x": True}
def test_parse_json_markdown_braces_only():
src = ' {\n \t"ok": "yes"\n} '
assert parse_json_markdown(src) == {"ok": "yes"}
def test_parse_json_markdown_not_found():
with pytest.raises(ValueError):
parse_json_markdown("no json here")
def test_parse_and_check_json_markdown_missing_key():
src = """
```
{"present": 1}
```
"""
with pytest.raises(OutputParserError) as exc:
parse_and_check_json_markdown(src, ["present", "missing"])
assert "expected key `missing`" in str(exc.value)
def test_parse_and_check_json_markdown_invalid_json():
src = """
```json
{invalid json}
```
"""
with pytest.raises(OutputParserError) as exc:
parse_and_check_json_markdown(src, [])
assert "got invalid json object" in str(exc.value)
def test_parse_and_check_json_markdown_success():
src = """
```json
{"present": 1, "other": 2}
```
"""
obj = parse_and_check_json_markdown(src, ["present"])
assert obj == {"present": 1, "other": 2}
def test_parse_and_check_json_markdown_multiple_blocks_fails():
src = """
```json
{"a": 1}
```
Some text
```json
{"b": 2}
```
"""
# The current implementation is greedy and will match from the first
# opening fence to the last closing fence, causing JSON decode failure.
with pytest.raises(OutputParserError):
parse_and_check_json_markdown(src, [])
def test_parse_and_check_json_markdown_handles_think_fenced_and_raw_variants():
expected = {"keywords": ["2"], "category_id": "2", "category_name": "2"}
cases = [
"""
```json
[{"keywords": ["2"], "category_id": "2", "category_name": "2"}]
```, error: Expecting value: line 1 column 1 (char 0)
""",
"""
```json
{"keywords": ["2"], "category_id": "2", "category_name": "2"}
```, error: Extra data: line 2 column 5 (char 66)
""",
'{"keywords": ["2"], "category_id": "2", "category_name": "2"}',
'[{"keywords": ["2"], "category_id": "2", "category_name": "2"}]',
]
for src in cases:
obj = parse_and_check_json_markdown(src, ["keywords", "category_id", "category_name"])
assert obj == expected

View File

@@ -0,0 +1,63 @@
"""Test PyJWT import paths to catch changes in library structure."""
import pytest
class TestPyJWTImports:
"""Test PyJWT import paths used throughout the codebase."""
def test_invalid_token_error_import(self):
"""Test that InvalidTokenError can be imported as used in login controller."""
# This test verifies the import path used in controllers/web/login.py:2
# If PyJWT changes this import path, this test will fail early
try:
from jwt import InvalidTokenError
# Verify it's the correct exception class
assert issubclass(InvalidTokenError, Exception)
# Test that it can be instantiated
error = InvalidTokenError("test error")
assert str(error) == "test error"
except ImportError as e:
pytest.fail(f"Failed to import InvalidTokenError from jwt: {e}")
def test_jwt_exceptions_import(self):
"""Test that jwt.exceptions imports work as expected."""
# Alternative import path that might be used
try:
# Verify it's the same class as the direct import
from jwt import InvalidTokenError
from jwt.exceptions import InvalidTokenError as InvalidTokenErrorAlt
assert InvalidTokenError is InvalidTokenErrorAlt
except ImportError as e:
pytest.fail(f"Failed to import InvalidTokenError from jwt.exceptions: {e}")
def test_other_jwt_exceptions_available(self):
"""Test that other common JWT exceptions are available."""
# Test other exceptions that might be used in the codebase
try:
from jwt import DecodeError, ExpiredSignatureError, InvalidSignatureError
# Verify they are exception classes
assert issubclass(DecodeError, Exception)
assert issubclass(ExpiredSignatureError, Exception)
assert issubclass(InvalidSignatureError, Exception)
except ImportError as e:
pytest.fail(f"Failed to import JWT exceptions: {e}")
def test_jwt_main_functions_available(self):
"""Test that main JWT functions are available."""
try:
from jwt import decode, encode
# Verify they are callable
assert callable(decode)
assert callable(encode)
except ImportError as e:
pytest.fail(f"Failed to import JWT main functions: {e}")

View File

@@ -0,0 +1,243 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, g
from flask_login import LoginManager, UserMixin
from libs.login import _get_user, current_user, login_required
class MockUser(UserMixin):
"""Mock user class for testing."""
def __init__(self, id: str, is_authenticated: bool = True):
self.id = id
self._is_authenticated = is_authenticated
@property
def is_authenticated(self):
return self._is_authenticated
def mock_csrf_check(*args, **kwargs):
return
class TestLoginRequired:
"""Test cases for login_required decorator."""
@pytest.fixture
@patch("libs.login.check_csrf_token", mock_csrf_check)
def setup_app(self, app: Flask):
"""Set up Flask app with login manager."""
# Initialize login manager
login_manager = LoginManager()
login_manager.init_app(app)
# Mock unauthorized handler
login_manager.unauthorized = MagicMock(return_value="Unauthorized")
# Add a dummy user loader to prevent exceptions
@login_manager.user_loader
def load_user(user_id):
return None
return app
@patch("libs.login.check_csrf_token", mock_csrf_check)
def test_authenticated_user_can_access_protected_view(self, setup_app: Flask):
"""Test that authenticated users can access protected views."""
@login_required
def protected_view():
return "Protected content"
with setup_app.test_request_context():
# Mock authenticated user
mock_user = MockUser("test_user", is_authenticated=True)
with patch("libs.login._get_user", return_value=mock_user):
result = protected_view()
assert result == "Protected content"
@patch("libs.login.check_csrf_token", mock_csrf_check)
def test_unauthenticated_user_cannot_access_protected_view(self, setup_app: Flask):
"""Test that unauthenticated users are redirected."""
@login_required
def protected_view():
return "Protected content"
with setup_app.test_request_context():
# Mock unauthenticated user
mock_user = MockUser("test_user", is_authenticated=False)
with patch("libs.login._get_user", return_value=mock_user):
result = protected_view()
assert result == "Unauthorized"
setup_app.login_manager.unauthorized.assert_called_once()
@patch("libs.login.check_csrf_token", mock_csrf_check)
def test_login_disabled_allows_unauthenticated_access(self, setup_app: Flask):
"""Test that LOGIN_DISABLED config bypasses authentication."""
@login_required
def protected_view():
return "Protected content"
with setup_app.test_request_context():
# Mock unauthenticated user and LOGIN_DISABLED
mock_user = MockUser("test_user", is_authenticated=False)
with patch("libs.login._get_user", return_value=mock_user):
with patch("libs.login.dify_config") as mock_config:
mock_config.LOGIN_DISABLED = True
result = protected_view()
assert result == "Protected content"
# Ensure unauthorized was not called
setup_app.login_manager.unauthorized.assert_not_called()
@patch("libs.login.check_csrf_token", mock_csrf_check)
def test_options_request_bypasses_authentication(self, setup_app: Flask):
"""Test that OPTIONS requests are exempt from authentication."""
@login_required
def protected_view():
return "Protected content"
with setup_app.test_request_context(method="OPTIONS"):
# Mock unauthenticated user
mock_user = MockUser("test_user", is_authenticated=False)
with patch("libs.login._get_user", return_value=mock_user):
result = protected_view()
assert result == "Protected content"
# Ensure unauthorized was not called
setup_app.login_manager.unauthorized.assert_not_called()
@patch("libs.login.check_csrf_token", mock_csrf_check)
def test_flask_2_compatibility(self, setup_app: Flask):
"""Test Flask 2.x compatibility with ensure_sync."""
@login_required
def protected_view():
return "Protected content"
# Mock Flask 2.x ensure_sync
setup_app.ensure_sync = MagicMock(return_value=lambda: "Synced content")
with setup_app.test_request_context():
mock_user = MockUser("test_user", is_authenticated=True)
with patch("libs.login._get_user", return_value=mock_user):
result = protected_view()
assert result == "Synced content"
setup_app.ensure_sync.assert_called_once()
@patch("libs.login.check_csrf_token", mock_csrf_check)
def test_flask_1_compatibility(self, setup_app: Flask):
"""Test Flask 1.x compatibility without ensure_sync."""
@login_required
def protected_view():
return "Protected content"
# Remove ensure_sync to simulate Flask 1.x
if hasattr(setup_app, "ensure_sync"):
delattr(setup_app, "ensure_sync")
with setup_app.test_request_context():
mock_user = MockUser("test_user", is_authenticated=True)
with patch("libs.login._get_user", return_value=mock_user):
result = protected_view()
assert result == "Protected content"
class TestGetUser:
"""Test cases for _get_user function."""
def test_get_user_returns_user_from_g(self, app: Flask):
"""Test that _get_user returns user from g._login_user."""
mock_user = MockUser("test_user")
with app.test_request_context():
g._login_user = mock_user
user = _get_user()
assert user == mock_user
assert user.id == "test_user"
def test_get_user_loads_user_if_not_in_g(self, app: Flask):
"""Test that _get_user loads user if not already in g."""
mock_user = MockUser("test_user")
# Mock login manager
login_manager = MagicMock()
login_manager._load_user = MagicMock()
app.login_manager = login_manager
with app.test_request_context():
# Simulate _load_user setting g._login_user
def side_effect():
g._login_user = mock_user
login_manager._load_user.side_effect = side_effect
user = _get_user()
assert user == mock_user
login_manager._load_user.assert_called_once()
def test_get_user_returns_none_without_request_context(self, app: Flask):
"""Test that _get_user returns None outside request context."""
# Outside of request context
user = _get_user()
assert user is None
class TestCurrentUser:
"""Test cases for current_user proxy."""
def test_current_user_proxy_returns_authenticated_user(self, app: Flask):
"""Test that current_user proxy returns authenticated user."""
mock_user = MockUser("test_user", is_authenticated=True)
with app.test_request_context():
with patch("libs.login._get_user", return_value=mock_user):
assert current_user.id == "test_user"
assert current_user.is_authenticated is True
def test_current_user_proxy_returns_none_when_no_user(self, app: Flask):
"""Test that current_user proxy handles None user."""
with app.test_request_context():
with patch("libs.login._get_user", return_value=None):
# When _get_user returns None, accessing attributes should fail
# or current_user should evaluate to falsy
try:
# Try to access an attribute that would exist on a real user
_ = current_user.id
pytest.fail("Should have raised AttributeError")
except AttributeError:
# This is expected when current_user is None
pass
def test_current_user_proxy_thread_safety(self, app: Flask):
"""Test that current_user proxy is thread-safe."""
import threading
results = {}
def check_user_in_thread(user_id: str, index: int):
with app.test_request_context():
mock_user = MockUser(user_id)
with patch("libs.login._get_user", return_value=mock_user):
results[index] = current_user.id
# Create multiple threads with different users
threads = []
for i in range(5):
thread = threading.Thread(target=check_user_in_thread, args=(f"user_{i}", i))
threads.append(thread)
thread.start()
# Wait for all threads to complete
for thread in threads:
thread.join()
# Verify each thread got its own user
for i in range(5):
assert results[i] == f"user_{i}"

View File

@@ -0,0 +1,19 @@
import pytest
from libs.oauth import OAuth
def test_oauth_base_methods_raise_not_implemented():
oauth = OAuth(client_id="id", client_secret="sec", redirect_uri="uri")
with pytest.raises(NotImplementedError):
oauth.get_authorization_url()
with pytest.raises(NotImplementedError):
oauth.get_access_token("code")
with pytest.raises(NotImplementedError):
oauth.get_raw_user_info("token")
with pytest.raises(NotImplementedError):
oauth._transform_user_info({})

View File

@@ -0,0 +1,249 @@
import urllib.parse
from unittest.mock import MagicMock, patch
import httpx
import pytest
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
class BaseOAuthTest:
"""Base class for OAuth provider tests with common fixtures"""
@pytest.fixture
def oauth_config(self):
return {
"client_id": "test_client_id",
"client_secret": "test_client_secret",
"redirect_uri": "http://localhost/callback",
}
@pytest.fixture
def mock_response(self):
response = MagicMock()
response.json.return_value = {}
return response
def parse_auth_url(self, url):
"""Helper to parse authorization URL"""
parsed = urllib.parse.urlparse(url)
params = urllib.parse.parse_qs(parsed.query)
return parsed, params
class TestGitHubOAuth(BaseOAuthTest):
@pytest.fixture
def oauth(self, oauth_config):
return GitHubOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
@pytest.mark.parametrize(
("invite_token", "expected_state"),
[
(None, None),
("test_invite_token", "test_invite_token"),
("", None),
],
)
def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
url = oauth.get_authorization_url(invite_token)
parsed, params = self.parse_auth_url(url)
assert parsed.scheme == "https"
assert parsed.netloc == "github.com"
assert parsed.path == "/login/oauth/authorize"
assert params["client_id"][0] == oauth_config["client_id"]
assert params["redirect_uri"][0] == oauth_config["redirect_uri"]
assert params["scope"][0] == "user:email"
if expected_state:
assert params["state"][0] == expected_state
else:
assert "state" not in params
@pytest.mark.parametrize(
("response_data", "expected_token", "should_raise"),
[
({"access_token": "test_token"}, "test_token", False),
({"error": "invalid_grant"}, None, True),
({}, None, True),
],
)
@patch("httpx.post")
def test_should_retrieve_access_token(
self, mock_post, oauth, mock_response, response_data, expected_token, should_raise
):
mock_response.json.return_value = response_data
mock_post.return_value = mock_response
if should_raise:
with pytest.raises(ValueError) as exc_info:
oauth.get_access_token("test_code")
assert "Error in GitHub OAuth" in str(exc_info.value)
else:
token = oauth.get_access_token("test_code")
assert token == expected_token
@pytest.mark.parametrize(
("user_data", "email_data", "expected_email"),
[
# User with primary email
(
{"id": 12345, "login": "testuser", "name": "Test User"},
[
{"email": "secondary@example.com", "primary": False},
{"email": "primary@example.com", "primary": True},
],
"primary@example.com",
),
# User with no emails - fallback to noreply
({"id": 12345, "login": "testuser", "name": "Test User"}, [], "12345+testuser@users.noreply.github.com"),
# User with only secondary email - fallback to noreply
(
{"id": 12345, "login": "testuser", "name": "Test User"},
[{"email": "secondary@example.com", "primary": False}],
"12345+testuser@users.noreply.github.com",
),
],
)
@patch("httpx.get")
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email):
user_response = MagicMock()
user_response.json.return_value = user_data
email_response = MagicMock()
email_response.json.return_value = email_data
mock_get.side_effect = [user_response, email_response]
user_info = oauth.get_user_info("test_token")
assert user_info.id == str(user_data["id"])
assert user_info.name == user_data["name"]
assert user_info.email == expected_email
@patch("httpx.get")
def test_should_handle_network_errors(self, mock_get, oauth):
mock_get.side_effect = httpx.RequestError("Network error")
with pytest.raises(httpx.RequestError):
oauth.get_raw_user_info("test_token")
class TestGoogleOAuth(BaseOAuthTest):
@pytest.fixture
def oauth(self, oauth_config):
return GoogleOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
@pytest.mark.parametrize(
("invite_token", "expected_state"),
[
(None, None),
("test_invite_token", "test_invite_token"),
("", None),
],
)
def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
url = oauth.get_authorization_url(invite_token)
parsed, params = self.parse_auth_url(url)
assert parsed.scheme == "https"
assert parsed.netloc == "accounts.google.com"
assert parsed.path == "/o/oauth2/v2/auth"
assert params["client_id"][0] == oauth_config["client_id"]
assert params["redirect_uri"][0] == oauth_config["redirect_uri"]
assert params["response_type"][0] == "code"
assert params["scope"][0] == "openid email"
if expected_state:
assert params["state"][0] == expected_state
else:
assert "state" not in params
@pytest.mark.parametrize(
("response_data", "expected_token", "should_raise"),
[
({"access_token": "test_token"}, "test_token", False),
({"error": "invalid_grant"}, None, True),
({}, None, True),
],
)
@patch("httpx.post")
def test_should_retrieve_access_token(
self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise
):
mock_response.json.return_value = response_data
mock_post.return_value = mock_response
if should_raise:
with pytest.raises(ValueError) as exc_info:
oauth.get_access_token("test_code")
assert "Error in Google OAuth" in str(exc_info.value)
else:
token = oauth.get_access_token("test_code")
assert token == expected_token
mock_post.assert_called_once_with(
oauth._TOKEN_URL,
data={
"client_id": oauth_config["client_id"],
"client_secret": oauth_config["client_secret"],
"code": "test_code",
"grant_type": "authorization_code",
"redirect_uri": oauth_config["redirect_uri"],
},
headers={"Accept": "application/json"},
)
@pytest.mark.parametrize(
("user_data", "expected_name"),
[
({"sub": "123", "email": "test@example.com", "email_verified": True}, ""),
({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string
],
)
@patch("httpx.get")
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name):
mock_response.json.return_value = user_data
mock_get.return_value = mock_response
user_info = oauth.get_user_info("test_token")
assert user_info.id == user_data["sub"]
assert user_info.name == expected_name
assert user_info.email == user_data["email"]
mock_get.assert_called_once_with(oauth._USER_INFO_URL, headers={"Authorization": "Bearer test_token"})
@pytest.mark.parametrize(
"exception_type",
[
httpx.HTTPError,
httpx.ConnectError,
httpx.TimeoutException,
],
)
@patch("httpx.get")
def test_should_handle_http_errors(self, mock_get, oauth, exception_type):
mock_response = MagicMock()
mock_response.raise_for_status.side_effect = exception_type("Error")
mock_get.return_value = mock_response
with pytest.raises(exception_type):
oauth.get_raw_user_info("invalid_token")
class TestOAuthUserInfo:
@pytest.mark.parametrize(
"user_data",
[
{"id": "123", "name": "Test User", "email": "test@example.com"},
{"id": "456", "name": "", "email": "user@domain.com"},
{"id": "789", "name": "Another User", "email": "another@test.org"},
],
)
def test_should_create_user_info_dataclass(self, user_data):
user_info = OAuthUserInfo(**user_data)
assert user_info.id == user_data["id"]
assert user_info.name == user_data["name"]
assert user_info.email == user_data["email"]

View File

@@ -0,0 +1,25 @@
import orjson
import pytest
from libs.orjson import orjson_dumps
def test_orjson_dumps_round_trip_basic():
obj = {"a": 1, "b": [1, 2, 3], "c": {"d": True}}
s = orjson_dumps(obj)
assert orjson.loads(s) == obj
def test_orjson_dumps_with_unicode_and_indent():
obj = {"msg": "你好Dify"}
s = orjson_dumps(obj, option=orjson.OPT_INDENT_2)
# contains indentation newline/spaces
assert "\n" in s
assert orjson.loads(s) == obj
def test_orjson_dumps_non_utf8_encoding_fails():
obj = {"msg": "你好"}
# orjson.dumps() always produces UTF-8 bytes; decoding with non-UTF8 fails.
with pytest.raises(UnicodeDecodeError):
orjson_dumps(obj, encoding="ascii")

View File

@@ -0,0 +1,58 @@
import pandas as pd
def test_pandas_csv(tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)
data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]}
df1 = pd.DataFrame(data)
# write to csv file
csv_file_path = tmp_path.joinpath("example.csv")
df1.to_csv(csv_file_path, index=False)
# read from csv file
df2 = pd.read_csv(csv_file_path, on_bad_lines="skip")
assert df2[df2.columns[0]].to_list() == data["col1"]
assert df2[df2.columns[1]].to_list() == data["col2"]
def test_pandas_xlsx(tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)
data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]}
df1 = pd.DataFrame(data)
# write to xlsx file
xlsx_file_path = tmp_path.joinpath("example.xlsx")
df1.to_excel(xlsx_file_path, index=False)
# read from xlsx file
df2 = pd.read_excel(xlsx_file_path)
assert df2[df2.columns[0]].to_list() == data["col1"]
assert df2[df2.columns[1]].to_list() == data["col2"]
def test_pandas_xlsx_with_sheets(tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)
data1 = {"col1": [1, 2, 3, 4, 5], "col2": ["A", "B", "C", "D", "E"]}
df1 = pd.DataFrame(data1)
data2 = {"col1": [6, 7, 8, 9, 10], "col2": ["F", "G", "H", "I", "J"]}
df2 = pd.DataFrame(data2)
# write to xlsx file with sheets
xlsx_file_path = tmp_path.joinpath("example_with_sheets.xlsx")
sheet1 = "Sheet1"
sheet2 = "Sheet2"
with pd.ExcelWriter(xlsx_file_path) as excel_writer:
df1.to_excel(excel_writer, sheet_name=sheet1, index=False)
df2.to_excel(excel_writer, sheet_name=sheet2, index=False)
# read from xlsx file with sheets
with pd.ExcelFile(xlsx_file_path) as excel_file:
df1 = pd.read_excel(excel_file, sheet_name=sheet1)
assert df1[df1.columns[0]].to_list() == data1["col1"]
assert df1[df1.columns[1]].to_list() == data1["col2"]
df2 = pd.read_excel(excel_file, sheet_name=sheet2)
assert df2[df2.columns[0]].to_list() == data2["col1"]
assert df2[df2.columns[1]].to_list() == data2["col2"]

View File

@@ -0,0 +1,205 @@
from datetime import UTC, datetime, timedelta
from unittest.mock import patch
import jwt
import pytest
from werkzeug.exceptions import Unauthorized
from libs.passport import PassportService
class TestPassportService:
"""Test PassportService JWT operations"""
@pytest.fixture
def passport_service(self):
"""Create PassportService instance with test secret key"""
with patch("libs.passport.dify_config") as mock_config:
mock_config.SECRET_KEY = "test-secret-key-for-testing"
return PassportService()
@pytest.fixture
def another_passport_service(self):
"""Create another PassportService instance with different secret key"""
with patch("libs.passport.dify_config") as mock_config:
mock_config.SECRET_KEY = "another-secret-key-for-testing"
return PassportService()
# Core functionality tests
def test_should_issue_and_verify_token(self, passport_service):
"""Test complete JWT lifecycle: issue and verify"""
payload = {"user_id": "123", "app_code": "test-app"}
token = passport_service.issue(payload)
# Verify token format
assert isinstance(token, str)
assert len(token.split(".")) == 3 # JWT format: header.payload.signature
# Verify token content
decoded = passport_service.verify(token)
assert decoded == payload
def test_should_handle_different_payload_types(self, passport_service):
"""Test issuing and verifying tokens with different payload types"""
test_cases = [
{"string": "value"},
{"number": 42},
{"float": 3.14},
{"boolean": True},
{"null": None},
{"array": [1, 2, 3]},
{"nested": {"key": "value"}},
{"unicode": "中文测试"},
{"emoji": "🔐"},
{}, # Empty payload
]
for payload in test_cases:
token = passport_service.issue(payload)
decoded = passport_service.verify(token)
assert decoded == payload
# Security tests
def test_should_reject_modified_token(self, passport_service):
"""Test that any modification to token invalidates it"""
token = passport_service.issue({"user": "test"})
# Test multiple modification points
test_positions = [0, len(token) // 3, len(token) // 2, len(token) - 1]
for pos in test_positions:
if pos < len(token) and token[pos] != ".":
# Change one character
tampered = token[:pos] + ("X" if token[pos] != "X" else "Y") + token[pos + 1 :]
with pytest.raises(Unauthorized):
passport_service.verify(tampered)
def test_should_reject_token_with_different_secret_key(self, passport_service, another_passport_service):
"""Test key isolation - token from one service should not work with another"""
payload = {"user_id": "123", "app_code": "test-app"}
token = passport_service.issue(payload)
with pytest.raises(Unauthorized) as exc_info:
another_passport_service.verify(token)
assert str(exc_info.value) == "401 Unauthorized: Invalid token signature."
def test_should_use_hs256_algorithm(self, passport_service):
"""Test that HS256 algorithm is used for signing"""
payload = {"test": "data"}
token = passport_service.issue(payload)
# Decode header without relying on JWT internals
# Use jwt.get_unverified_header which is a public API
header = jwt.get_unverified_header(token)
assert header["alg"] == "HS256"
def test_should_reject_token_with_wrong_algorithm(self, passport_service):
"""Test rejection of token signed with different algorithm"""
payload = {"user_id": "123"}
# Create token with different algorithm
with patch("libs.passport.dify_config") as mock_config:
mock_config.SECRET_KEY = "test-secret-key-for-testing"
# Create token with HS512 instead of HS256
wrong_alg_token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS512")
# Should fail because service expects HS256
# InvalidAlgorithmError is now caught by PyJWTError handler
with pytest.raises(Unauthorized) as exc_info:
passport_service.verify(wrong_alg_token)
assert str(exc_info.value) == "401 Unauthorized: Invalid token."
# Exception handling tests
def test_should_handle_invalid_tokens(self, passport_service):
"""Test handling of various invalid token formats"""
invalid_tokens = [
("not.a.token", "Invalid token."),
("invalid-jwt-format", "Invalid token."),
("xxx.yyy.zzz", "Invalid token."),
("a.b", "Invalid token."), # Missing signature
("", "Invalid token."), # Empty string
(" ", "Invalid token."), # Whitespace
(None, "Invalid token."), # None value
# Malformed base64
("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.INVALID_BASE64!@#$.signature", "Invalid token."),
]
for invalid_token, expected_message in invalid_tokens:
with pytest.raises(Unauthorized) as exc_info:
passport_service.verify(invalid_token)
assert expected_message in str(exc_info.value)
def test_should_reject_expired_token(self, passport_service):
"""Test rejection of expired token"""
past_time = datetime.now(UTC) - timedelta(hours=1)
payload = {"user_id": "123", "exp": past_time.timestamp()}
with patch("libs.passport.dify_config") as mock_config:
mock_config.SECRET_KEY = "test-secret-key-for-testing"
token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS256")
with pytest.raises(Unauthorized) as exc_info:
passport_service.verify(token)
assert str(exc_info.value) == "401 Unauthorized: Token has expired."
# Configuration tests
def test_should_handle_empty_secret_key(self):
"""Test behavior when SECRET_KEY is empty"""
with patch("libs.passport.dify_config") as mock_config:
mock_config.SECRET_KEY = ""
service = PassportService()
# Empty secret key should still work but is insecure
payload = {"test": "data"}
token = service.issue(payload)
decoded = service.verify(token)
assert decoded == payload
def test_should_handle_none_secret_key(self):
"""Test behavior when SECRET_KEY is None"""
with patch("libs.passport.dify_config") as mock_config:
mock_config.SECRET_KEY = None
service = PassportService()
payload = {"test": "data"}
# JWT library will raise TypeError when secret is None
with pytest.raises((TypeError, jwt.exceptions.InvalidKeyError)):
service.issue(payload)
# Boundary condition tests
def test_should_handle_large_payload(self, passport_service):
"""Test handling of large payload"""
# Test with 100KB instead of 1MB for faster tests
large_data = "x" * (100 * 1024)
payload = {"data": large_data}
token = passport_service.issue(payload)
decoded = passport_service.verify(token)
assert decoded["data"] == large_data
def test_should_handle_special_characters_in_payload(self, passport_service):
"""Test handling of special characters in payload"""
special_payloads = [
{"special": "!@#$%^&*()"},
{"quotes": 'He said "Hello"'},
{"backslash": "path\\to\\file"},
{"newline": "line1\nline2"},
{"unicode": "🔐🔑🛡️"},
{"mixed": "Test123!@#中文🔐"},
]
for payload in special_payloads:
token = passport_service.issue(payload)
decoded = passport_service.verify(token)
assert decoded == payload
def test_should_catch_generic_pyjwt_errors(self, passport_service):
"""Test that generic PyJWTError exceptions are caught and converted to Unauthorized"""
# Mock jwt.decode to raise a generic PyJWTError
with patch("libs.passport.jwt.decode") as mock_decode:
mock_decode.side_effect = jwt.exceptions.PyJWTError("Generic JWT error")
with pytest.raises(Unauthorized) as exc_info:
passport_service.verify("some-token")
assert str(exc_info.value) == "401 Unauthorized: Invalid token."

View File

@@ -0,0 +1,74 @@
import base64
import binascii
import os
import pytest
from libs.password import compare_password, hash_password, valid_password
class TestValidPassword:
"""Test password format validation"""
def test_should_accept_valid_passwords(self):
"""Test accepting valid password formats"""
assert valid_password("password123") == "password123"
assert valid_password("test1234") == "test1234"
assert valid_password("Test123456") == "Test123456"
def test_should_reject_invalid_passwords(self):
"""Test rejecting invalid password formats"""
# Too short
with pytest.raises(ValueError) as exc_info:
valid_password("abc123")
assert "Password must contain letters and numbers" in str(exc_info.value)
# No numbers
with pytest.raises(ValueError):
valid_password("abcdefgh")
# No letters
with pytest.raises(ValueError):
valid_password("12345678")
# Empty
with pytest.raises(ValueError):
valid_password("")
class TestPasswordHashing:
"""Test password hashing and comparison"""
def setup_method(self):
"""Setup test data"""
self.password = "test123password"
self.salt = os.urandom(16)
self.salt_base64 = base64.b64encode(self.salt).decode()
password_hash = hash_password(self.password, self.salt)
self.password_hash_base64 = base64.b64encode(password_hash).decode()
def test_should_verify_correct_password(self):
"""Test correct password verification"""
result = compare_password(self.password, self.password_hash_base64, self.salt_base64)
assert result is True
def test_should_reject_wrong_password(self):
"""Test rejection of incorrect passwords"""
result = compare_password("wrongpassword", self.password_hash_base64, self.salt_base64)
assert result is False
def test_should_handle_invalid_base64(self):
"""Test handling of invalid base64 data"""
# Invalid base64 hash
with pytest.raises(binascii.Error):
compare_password(self.password, "invalid_base64!", self.salt_base64)
# Invalid base64 salt
with pytest.raises(binascii.Error):
compare_password(self.password, self.password_hash_base64, "invalid_base64!")
def test_should_be_case_sensitive(self):
"""Test password case sensitivity"""
result = compare_password(self.password.upper(), self.password_hash_base64, self.salt_base64)
assert result is False

View File

@@ -0,0 +1,29 @@
import rsa as pyrsa
from Crypto.PublicKey import RSA
from libs import gmpy2_pkcs10aep_cipher
def test_gmpy2_pkcs10aep_cipher():
rsa_key_pair = pyrsa.newkeys(2048)
public_key = rsa_key_pair[0].save_pkcs1()
private_key = rsa_key_pair[1].save_pkcs1()
public_rsa_key = RSA.import_key(public_key)
public_cipher_rsa2 = gmpy2_pkcs10aep_cipher.new(public_rsa_key)
private_rsa_key = RSA.import_key(private_key)
private_cipher_rsa = gmpy2_pkcs10aep_cipher.new(private_rsa_key)
raw_text = "raw_text"
raw_text_bytes = raw_text.encode()
# RSA encryption by public key and decryption by private key
encrypted_by_pub_key = public_cipher_rsa2.encrypt(message=raw_text_bytes)
decrypted_by_pub_key = private_cipher_rsa.decrypt(encrypted_by_pub_key)
assert decrypted_by_pub_key == raw_text_bytes
# RSA encryption and decryption by private key
encrypted_by_private_key = private_cipher_rsa.encrypt(message=raw_text_bytes)
decrypted_by_private_key = private_cipher_rsa.decrypt(encrypted_by_private_key)
assert decrypted_by_private_key == raw_text_bytes

View File

@@ -0,0 +1,411 @@
"""
Enhanced schedule_utils tests for new cron syntax support.
These tests verify that the backend schedule_utils functions properly support
the enhanced cron syntax introduced in the frontend, ensuring full compatibility.
"""
import unittest
from datetime import UTC, datetime, timedelta
import pytest
import pytz
from croniter import CroniterBadCronError
from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h
class TestEnhancedCronSyntax(unittest.TestCase):
"""Test enhanced cron syntax in calculate_next_run_at."""
def setUp(self):
"""Set up test with fixed time."""
# Monday, January 15, 2024, 10:00 AM UTC
self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)
def test_month_abbreviations(self):
"""Test month abbreviations (JAN, FEB, etc.)."""
test_cases = [
("0 12 1 JAN *", 1), # January
("0 12 1 FEB *", 2), # February
("0 12 1 MAR *", 3), # March
("0 12 1 APR *", 4), # April
("0 12 1 MAY *", 5), # May
("0 12 1 JUN *", 6), # June
("0 12 1 JUL *", 7), # July
("0 12 1 AUG *", 8), # August
("0 12 1 SEP *", 9), # September
("0 12 1 OCT *", 10), # October
("0 12 1 NOV *", 11), # November
("0 12 1 DEC *", 12), # December
]
for expr, expected_month in test_cases:
with self.subTest(expr=expr):
result = calculate_next_run_at(expr, "UTC", self.base_time)
assert result is not None, f"Failed to parse: {expr}"
assert result.month == expected_month
assert result.day == 1
assert result.hour == 12
assert result.minute == 0
def test_weekday_abbreviations(self):
"""Test weekday abbreviations (SUN, MON, etc.)."""
test_cases = [
("0 9 * * SUN", 6), # Sunday (weekday() = 6)
("0 9 * * MON", 0), # Monday (weekday() = 0)
("0 9 * * TUE", 1), # Tuesday
("0 9 * * WED", 2), # Wednesday
("0 9 * * THU", 3), # Thursday
("0 9 * * FRI", 4), # Friday
("0 9 * * SAT", 5), # Saturday
]
for expr, expected_weekday in test_cases:
with self.subTest(expr=expr):
result = calculate_next_run_at(expr, "UTC", self.base_time)
assert result is not None, f"Failed to parse: {expr}"
assert result.weekday() == expected_weekday
assert result.hour == 9
assert result.minute == 0
def test_sunday_dual_representation(self):
"""Test Sunday as both 0 and 7."""
base_time = datetime(2024, 1, 14, 10, 0, 0, tzinfo=UTC) # Sunday
# Both should give the same next Sunday
result_0 = calculate_next_run_at("0 10 * * 0", "UTC", base_time)
result_7 = calculate_next_run_at("0 10 * * 7", "UTC", base_time)
result_SUN = calculate_next_run_at("0 10 * * SUN", "UTC", base_time)
assert result_0 is not None
assert result_7 is not None
assert result_SUN is not None
# All should be Sundays
assert result_0.weekday() == 6 # Sunday = 6 in weekday()
assert result_7.weekday() == 6
assert result_SUN.weekday() == 6
# Times should be identical
assert result_0 == result_7
assert result_0 == result_SUN
def test_predefined_expressions(self):
"""Test predefined expressions (@daily, @weekly, etc.)."""
test_cases = [
("@yearly", lambda dt: dt.month == 1 and dt.day == 1 and dt.hour == 0 and dt.minute == 0),
("@annually", lambda dt: dt.month == 1 and dt.day == 1 and dt.hour == 0 and dt.minute == 0),
("@monthly", lambda dt: dt.day == 1 and dt.hour == 0 and dt.minute == 0),
("@weekly", lambda dt: dt.weekday() == 6 and dt.hour == 0 and dt.minute == 0), # Sunday
("@daily", lambda dt: dt.hour == 0 and dt.minute == 0),
("@midnight", lambda dt: dt.hour == 0 and dt.minute == 0),
("@hourly", lambda dt: dt.minute == 0),
]
for expr, validator in test_cases:
with self.subTest(expr=expr):
result = calculate_next_run_at(expr, "UTC", self.base_time)
assert result is not None, f"Failed to parse: {expr}"
assert validator(result), f"Validator failed for {expr}: {result}"
def test_question_mark_wildcard(self):
"""Test ? wildcard character."""
# ? in day position with specific weekday
result_question = calculate_next_run_at("0 9 ? * 1", "UTC", self.base_time) # Monday
result_star = calculate_next_run_at("0 9 * * 1", "UTC", self.base_time) # Monday
assert result_question is not None
assert result_star is not None
# Both should return Mondays at 9:00
assert result_question.weekday() == 0 # Monday
assert result_star.weekday() == 0
assert result_question.hour == 9
assert result_star.hour == 9
# Results should be identical
assert result_question == result_star
def test_last_day_of_month(self):
"""Test 'L' for last day of month."""
expr = "0 12 L * *" # Last day of month at noon
# Test for February (28 days in 2024 - not a leap year check)
feb_base = datetime(2024, 2, 15, 10, 0, 0, tzinfo=UTC)
result = calculate_next_run_at(expr, "UTC", feb_base)
assert result is not None
assert result.month == 2
assert result.day == 29 # 2024 is a leap year
assert result.hour == 12
def test_range_with_abbreviations(self):
"""Test ranges using abbreviations."""
test_cases = [
"0 9 * * MON-FRI", # Weekday range
"0 12 * JAN-MAR *", # Q1 months
"0 15 * APR-JUN *", # Q2 months
]
for expr in test_cases:
with self.subTest(expr=expr):
result = calculate_next_run_at(expr, "UTC", self.base_time)
assert result is not None, f"Failed to parse range expression: {expr}"
assert result > self.base_time
def test_list_with_abbreviations(self):
"""Test lists using abbreviations."""
test_cases = [
("0 9 * * SUN,WED,FRI", [6, 2, 4]), # Specific weekdays
("0 12 1 JAN,JUN,DEC *", [1, 6, 12]), # Specific months
]
for expr, expected_values in test_cases:
with self.subTest(expr=expr):
result = calculate_next_run_at(expr, "UTC", self.base_time)
assert result is not None, f"Failed to parse list expression: {expr}"
if "* *" in expr: # Weekday test
assert result.weekday() in expected_values
else: # Month test
assert result.month in expected_values
def test_mixed_syntax(self):
"""Test mixed traditional and enhanced syntax."""
test_cases = [
"30 14 15 JAN,JUN,DEC *", # Numbers + month abbreviations
"0 9 * JAN-MAR MON-FRI", # Month range + weekday range
"45 8 1,15 * MON", # Numbers + weekday abbreviation
]
for expr in test_cases:
with self.subTest(expr=expr):
result = calculate_next_run_at(expr, "UTC", self.base_time)
assert result is not None, f"Failed to parse mixed syntax: {expr}"
assert result > self.base_time
def test_complex_enhanced_expressions(self):
"""Test complex expressions with multiple enhanced features."""
# Note: Some of these might not be supported by croniter, that's OK
complex_expressions = [
"0 9 L JAN *", # Last day of January
"30 14 * * FRI#1", # First Friday of month (if supported)
"0 12 15 JAN-DEC/3 *", # 15th of every 3rd month (quarterly)
]
for expr in complex_expressions:
with self.subTest(expr=expr):
try:
result = calculate_next_run_at(expr, "UTC", self.base_time)
if result: # If supported, should return valid result
assert result > self.base_time
except Exception:
# Some complex expressions might not be supported - that's acceptable
pass
class TestTimezoneHandlingEnhanced(unittest.TestCase):
"""Test timezone handling with enhanced syntax."""
def setUp(self):
"""Set up test with fixed time."""
self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)
def test_enhanced_syntax_with_timezones(self):
"""Test enhanced syntax works correctly across timezones."""
timezones = ["UTC", "America/New_York", "Asia/Tokyo", "Europe/London"]
expression = "0 12 * * MON" # Monday at noon
for timezone in timezones:
with self.subTest(timezone=timezone):
result = calculate_next_run_at(expression, timezone, self.base_time)
assert result is not None
# Convert to local timezone to verify it's Monday at noon
tz = pytz.timezone(timezone)
local_time = result.astimezone(tz)
assert local_time.weekday() == 0 # Monday
assert local_time.hour == 12
assert local_time.minute == 0
def test_predefined_expressions_with_timezones(self):
"""Test predefined expressions work with different timezones."""
expression = "@daily"
timezones = ["UTC", "America/New_York", "Asia/Tokyo"]
for timezone in timezones:
with self.subTest(timezone=timezone):
result = calculate_next_run_at(expression, timezone, self.base_time)
assert result is not None
# Should be midnight in the specified timezone
tz = pytz.timezone(timezone)
local_time = result.astimezone(tz)
assert local_time.hour == 0
assert local_time.minute == 0
def test_dst_with_enhanced_syntax(self):
"""Test DST handling with enhanced syntax."""
# DST spring forward date in 2024
dst_base = datetime(2024, 3, 8, 10, 0, 0, tzinfo=UTC)
expression = "0 2 * * SUN" # Sunday at 2 AM (problematic during DST)
timezone = "America/New_York"
result = calculate_next_run_at(expression, timezone, dst_base)
assert result is not None
# Should handle DST transition gracefully
tz = pytz.timezone(timezone)
local_time = result.astimezone(tz)
assert local_time.weekday() == 6 # Sunday
# During DST spring forward, 2 AM might become 3 AM
assert local_time.hour in [2, 3]
class TestErrorHandlingEnhanced(unittest.TestCase):
"""Test error handling for enhanced syntax."""
def setUp(self):
"""Set up test with fixed time."""
self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)
def test_invalid_enhanced_syntax(self):
"""Test that invalid enhanced syntax raises appropriate errors."""
invalid_expressions = [
"0 12 * JANUARY *", # Full month name
"0 12 * * MONDAY", # Full day name
"0 12 32 JAN *", # Invalid day with valid month
"0 12 * * MON-SUN-FRI", # Invalid range syntax
"0 12 * JAN- *", # Incomplete range
"0 12 * * ,MON", # Invalid list syntax
"@INVALID", # Invalid predefined
]
for expr in invalid_expressions:
with self.subTest(expr=expr):
with pytest.raises((CroniterBadCronError, ValueError)):
calculate_next_run_at(expr, "UTC", self.base_time)
def test_boundary_values_with_enhanced_syntax(self):
"""Test boundary values work with enhanced syntax."""
# Valid boundary expressions
valid_expressions = [
"0 0 1 JAN *", # Minimum: January 1st midnight
"59 23 31 DEC *", # Maximum: December 31st 23:59
"0 12 29 FEB *", # Leap year boundary
]
for expr in valid_expressions:
with self.subTest(expr=expr):
try:
result = calculate_next_run_at(expr, "UTC", self.base_time)
if result: # Some dates might not occur soon
assert result > self.base_time
except Exception as e:
# Some boundary cases might be complex to calculate
self.fail(f"Valid boundary expression failed: {expr} - {e}")
class TestPerformanceEnhanced(unittest.TestCase):
"""Test performance with enhanced syntax."""
def setUp(self):
"""Set up test with fixed time."""
self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)
def test_complex_expression_performance(self):
"""Test that complex enhanced expressions parse within reasonable time."""
import time
complex_expressions = [
"*/5 9-17 * * MON-FRI", # Every 5 min, weekdays, business hours
"0 9 * JAN-MAR MON-FRI", # Q1 weekdays at 9 AM
"30 14 1,15 * * ", # 1st and 15th at 14:30
"0 12 ? * SUN", # Sundays at noon with ?
"@daily", # Predefined expression
]
start_time = time.time()
for expr in complex_expressions:
with self.subTest(expr=expr):
try:
result = calculate_next_run_at(expr, "UTC", self.base_time)
assert result is not None
except Exception:
# Some expressions might not be supported - acceptable
pass
end_time = time.time()
execution_time = (end_time - start_time) * 1000 # milliseconds
# Should be fast (less than 100ms for all expressions)
assert execution_time < 100, "Enhanced expressions should parse quickly"
def test_multiple_calculations_performance(self):
"""Test performance when calculating multiple next times."""
import time
expression = "0 9 * * MON-FRI" # Weekdays at 9 AM
iterations = 20
start_time = time.time()
current_time = self.base_time
for _ in range(iterations):
result = calculate_next_run_at(expression, "UTC", current_time)
assert result is not None
current_time = result + timedelta(seconds=1) # Move forward slightly
end_time = time.time()
total_time = (end_time - start_time) * 1000 # milliseconds
avg_time = total_time / iterations
# Average should be very fast (less than 5ms per calculation)
assert avg_time < 5, f"Average calculation time too slow: {avg_time}ms"
class TestRegressionEnhanced(unittest.TestCase):
"""Regression tests to ensure enhanced syntax doesn't break existing functionality."""
def setUp(self):
"""Set up test with fixed time."""
self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)
def test_traditional_syntax_still_works(self):
"""Ensure traditional cron syntax continues to work."""
traditional_expressions = [
"15 10 1 * *", # Monthly 1st at 10:15
"0 0 * * 0", # Weekly Sunday midnight
"*/5 * * * *", # Every 5 minutes
"0 9-17 * * 1-5", # Business hours weekdays
"30 14 * * 1", # Monday 14:30
"0 0 1,15 * *", # 1st and 15th midnight
]
for expr in traditional_expressions:
with self.subTest(expr=expr):
result = calculate_next_run_at(expr, "UTC", self.base_time)
assert result is not None, f"Traditional expression failed: {expr}"
assert result > self.base_time
def test_convert_12h_to_24h_unchanged(self):
"""Ensure convert_12h_to_24h function is unchanged."""
test_cases = [
("12:00 AM", (0, 0)), # Midnight
("12:00 PM", (12, 0)), # Noon
("1:30 AM", (1, 30)), # Early morning
("11:45 PM", (23, 45)), # Late evening
("6:15 AM", (6, 15)), # Morning
("3:30 PM", (15, 30)), # Afternoon
]
for time_str, expected in test_cases:
with self.subTest(time_str=time_str):
result = convert_12h_to_24h(time_str)
assert result == expected, f"12h conversion failed: {time_str}"
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,53 @@
from unittest.mock import MagicMock, patch
import pytest
from python_http_client.exceptions import UnauthorizedError
from libs.sendgrid import SendGridClient
def _mail(to: str = "user@example.com") -> dict:
return {"to": to, "subject": "Hi", "html": "<b>Hi</b>"}
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
def test_sendgrid_success(mock_client_cls: MagicMock):
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
# nested attribute access: client.mail.send.post
mock_client.client.mail.send.post.return_value = MagicMock(status_code=202, body=b"", headers={})
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
sg.send(_mail())
mock_client_cls.assert_called_once()
mock_client.client.mail.send.post.assert_called_once()
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
def test_sendgrid_missing_to_raises(mock_client_cls: MagicMock):
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
with pytest.raises(ValueError):
sg.send(_mail(to=""))
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
def test_sendgrid_auth_errors_reraise(mock_client_cls: MagicMock):
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.client.mail.send.post.side_effect = UnauthorizedError(401, "Unauthorized", b"{}", {})
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
with pytest.raises(UnauthorizedError):
sg.send(_mail())
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
def test_sendgrid_timeout_reraise(mock_client_cls: MagicMock):
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.client.mail.send.post.side_effect = TimeoutError("timeout")
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
with pytest.raises(TimeoutError):
sg.send(_mail())

View File

@@ -0,0 +1,100 @@
from unittest.mock import MagicMock, patch
import pytest
from libs.smtp import SMTPClient
def _mail() -> dict:
return {"to": "user@example.com", "subject": "Hi", "html": "<b>Hi</b>"}
@patch("libs.smtp.smtplib.SMTP")
def test_smtp_plain_success(mock_smtp_cls: MagicMock):
mock_smtp = MagicMock()
mock_smtp_cls.return_value = mock_smtp
client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com")
client.send(_mail())
mock_smtp_cls.assert_called_once_with("smtp.example.com", 25, timeout=10)
mock_smtp.sendmail.assert_called_once()
mock_smtp.quit.assert_called_once()
@patch("libs.smtp.smtplib.SMTP")
def test_smtp_tls_opportunistic_success(mock_smtp_cls: MagicMock):
mock_smtp = MagicMock()
mock_smtp_cls.return_value = mock_smtp
client = SMTPClient(
server="smtp.example.com",
port=587,
username="user",
password="pass",
_from="noreply@example.com",
use_tls=True,
opportunistic_tls=True,
)
client.send(_mail())
mock_smtp_cls.assert_called_once_with("smtp.example.com", 587, timeout=10)
assert mock_smtp.ehlo.call_count == 2
mock_smtp.starttls.assert_called_once()
mock_smtp.login.assert_called_once_with("user", "pass")
mock_smtp.sendmail.assert_called_once()
mock_smtp.quit.assert_called_once()
@patch("libs.smtp.smtplib.SMTP_SSL")
def test_smtp_tls_ssl_branch_and_timeout(mock_smtp_ssl_cls: MagicMock):
# Cover SMTP_SSL branch and TimeoutError handling
mock_smtp = MagicMock()
mock_smtp.sendmail.side_effect = TimeoutError("timeout")
mock_smtp_ssl_cls.return_value = mock_smtp
client = SMTPClient(
server="smtp.example.com",
port=465,
username="",
password="",
_from="noreply@example.com",
use_tls=True,
opportunistic_tls=False,
)
with pytest.raises(TimeoutError):
client.send(_mail())
mock_smtp.quit.assert_called_once()
@patch("libs.smtp.smtplib.SMTP")
def test_smtp_generic_exception_propagates(mock_smtp_cls: MagicMock):
mock_smtp = MagicMock()
mock_smtp.sendmail.side_effect = RuntimeError("oops")
mock_smtp_cls.return_value = mock_smtp
client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com")
with pytest.raises(RuntimeError):
client.send(_mail())
mock_smtp.quit.assert_called_once()
@patch("libs.smtp.smtplib.SMTP")
def test_smtp_smtplib_exception_in_login(mock_smtp_cls: MagicMock):
# Ensure we hit the specific SMTPException except branch
import smtplib
mock_smtp = MagicMock()
mock_smtp.login.side_effect = smtplib.SMTPException("login-fail")
mock_smtp_cls.return_value = mock_smtp
client = SMTPClient(
server="smtp.example.com",
port=25,
username="user", # non-empty to trigger login
password="pass",
_from="noreply@example.com",
)
with pytest.raises(smtplib.SMTPException):
client.send(_mail())
mock_smtp.quit.assert_called_once()

View File

@@ -0,0 +1,91 @@
"""Unit tests for time parser utility."""
from datetime import UTC, datetime, timedelta
from libs.time_parser import get_time_threshold, parse_time_duration
class TestParseTimeDuration:
"""Test parse_time_duration function."""
def test_parse_days(self):
"""Test parsing days."""
result = parse_time_duration("7d")
assert result == timedelta(days=7)
def test_parse_hours(self):
"""Test parsing hours."""
result = parse_time_duration("4h")
assert result == timedelta(hours=4)
def test_parse_minutes(self):
"""Test parsing minutes."""
result = parse_time_duration("30m")
assert result == timedelta(minutes=30)
def test_parse_seconds(self):
"""Test parsing seconds."""
result = parse_time_duration("30s")
assert result == timedelta(seconds=30)
def test_parse_uppercase(self):
"""Test parsing uppercase units."""
result = parse_time_duration("7D")
assert result == timedelta(days=7)
def test_parse_invalid_format(self):
"""Test parsing invalid format."""
result = parse_time_duration("7days")
assert result is None
result = parse_time_duration("abc")
assert result is None
result = parse_time_duration("7")
assert result is None
def test_parse_empty_string(self):
"""Test parsing empty string."""
result = parse_time_duration("")
assert result is None
def test_parse_none(self):
"""Test parsing None."""
result = parse_time_duration(None)
assert result is None
class TestGetTimeThreshold:
"""Test get_time_threshold function."""
def test_get_threshold_days(self):
"""Test getting threshold for days."""
before = datetime.now(UTC)
result = get_time_threshold("7d")
after = datetime.now(UTC)
assert result is not None
# Result should be approximately 7 days ago
expected = before - timedelta(days=7)
# Allow 1 second tolerance for test execution time
assert abs((result - expected).total_seconds()) < 1
def test_get_threshold_hours(self):
"""Test getting threshold for hours."""
before = datetime.now(UTC)
result = get_time_threshold("4h")
after = datetime.now(UTC)
assert result is not None
expected = before - timedelta(hours=4)
assert abs((result - expected).total_seconds()) < 1
def test_get_threshold_invalid(self):
"""Test getting threshold with invalid duration."""
result = get_time_threshold("invalid")
assert result is None
def test_get_threshold_none(self):
"""Test getting threshold with None."""
result = get_time_threshold(None)
assert result is None

View File

@@ -0,0 +1,62 @@
from unittest.mock import MagicMock
from werkzeug.wrappers import Response
from constants import COOKIE_NAME_ACCESS_TOKEN, COOKIE_NAME_WEBAPP_ACCESS_TOKEN
from libs import token
from libs.token import extract_access_token, extract_webapp_access_token, set_csrf_token_to_cookie
class MockRequest:
def __init__(self, headers: dict[str, str], cookies: dict[str, str], args: dict[str, str]):
self.headers: dict[str, str] = headers
self.cookies: dict[str, str] = cookies
self.args: dict[str, str] = args
def test_extract_access_token():
def _mock_request(headers: dict[str, str], cookies: dict[str, str], args: dict[str, str]):
return MockRequest(headers, cookies, args)
test_cases = [
(_mock_request({"Authorization": "Bearer 123"}, {}, {}), "123", "123"),
(_mock_request({}, {COOKIE_NAME_ACCESS_TOKEN: "123"}, {}), "123", None),
(_mock_request({}, {}, {}), None, None),
(_mock_request({"Authorization": "Bearer_aaa 123"}, {}, {}), None, None),
(_mock_request({}, {COOKIE_NAME_WEBAPP_ACCESS_TOKEN: "123"}, {}), None, "123"),
]
for request, expected_console, expected_webapp in test_cases:
assert extract_access_token(request) == expected_console # pyright: ignore[reportArgumentType]
assert extract_webapp_access_token(request) == expected_webapp # pyright: ignore[reportArgumentType]
def test_real_cookie_name_uses_host_prefix_without_domain(monkeypatch):
monkeypatch.setattr(token.dify_config, "CONSOLE_WEB_URL", "https://console.example.com", raising=False)
monkeypatch.setattr(token.dify_config, "CONSOLE_API_URL", "https://api.example.com", raising=False)
monkeypatch.setattr(token.dify_config, "COOKIE_DOMAIN", "", raising=False)
assert token._real_cookie_name("csrf_token") == "__Host-csrf_token"
def test_real_cookie_name_without_host_prefix_when_domain_present(monkeypatch):
monkeypatch.setattr(token.dify_config, "CONSOLE_WEB_URL", "https://console.example.com", raising=False)
monkeypatch.setattr(token.dify_config, "CONSOLE_API_URL", "https://api.example.com", raising=False)
monkeypatch.setattr(token.dify_config, "COOKIE_DOMAIN", ".example.com", raising=False)
assert token._real_cookie_name("csrf_token") == "csrf_token"
def test_set_csrf_cookie_includes_domain_when_configured(monkeypatch):
monkeypatch.setattr(token.dify_config, "CONSOLE_WEB_URL", "https://console.example.com", raising=False)
monkeypatch.setattr(token.dify_config, "CONSOLE_API_URL", "https://api.example.com", raising=False)
monkeypatch.setattr(token.dify_config, "COOKIE_DOMAIN", ".example.com", raising=False)
response = Response()
request = MagicMock()
set_csrf_token_to_cookie(request, response, "abc123")
cookies = response.headers.getlist("Set-Cookie")
assert any("csrf_token=abc123" in c for c in cookies)
assert any("Domain=example.com" in c for c in cookies)
assert all("__Host-" not in c for c in cookies)

View File

@@ -0,0 +1,351 @@
import struct
import time
import uuid
from unittest import mock
import pytest
from hypothesis import given
from hypothesis import strategies as st
from libs.uuid_utils import _create_uuidv7_bytes, uuidv7, uuidv7_boundary, uuidv7_timestamp
# Tests for private helper function _create_uuidv7_bytes
def test_create_uuidv7_bytes_basic_structure():
"""Test basic byte structure creation."""
timestamp_ms = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
random_bytes = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x11\x22"
result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
# Should be exactly 16 bytes
assert len(result) == 16
assert isinstance(result, bytes)
# Create UUID from bytes to verify it's valid
uuid_obj = uuid.UUID(bytes=result)
assert uuid_obj.version == 7
def test_create_uuidv7_bytes_timestamp_encoding():
"""Test timestamp is correctly encoded in first 48 bits."""
timestamp_ms = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
random_bytes = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
# Extract timestamp from first 6 bytes
timestamp_bytes = b"\x00\x00" + result[0:6]
extracted_timestamp = struct.unpack(">Q", timestamp_bytes)[0]
assert extracted_timestamp == timestamp_ms
def test_create_uuidv7_bytes_version_bits():
"""Test version bits are set to 7."""
timestamp_ms = 1609459200000
random_bytes = b"\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00" # Set first 2 bytes to all 1s
result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
# Extract version from bytes 6-7
version_and_rand_a = struct.unpack(">H", result[6:8])[0]
version = (version_and_rand_a >> 12) & 0x0F
assert version == 7
def test_create_uuidv7_bytes_variant_bits():
"""Test variant bits are set correctly."""
timestamp_ms = 1609459200000
random_bytes = b"\x00\x00\xff\x00\x00\x00\x00\x00\x00\x00" # Set byte 8 to all 1s
result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
# Check variant bits in byte 8 (should be 10xxxxxx)
variant_byte = result[8]
variant_bits = (variant_byte >> 6) & 0b11
assert variant_bits == 0b10 # Should be binary 10
def test_create_uuidv7_bytes_random_data():
"""Test random bytes are placed correctly."""
timestamp_ms = 1609459200000
random_bytes = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x11\x22"
result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
# Check random data A (12 bits from bytes 6-7, excluding version)
version_and_rand_a = struct.unpack(">H", result[6:8])[0]
rand_a = version_and_rand_a & 0x0FFF
expected_rand_a = struct.unpack(">H", random_bytes[0:2])[0] & 0x0FFF
assert rand_a == expected_rand_a
# Check random data B (bytes 8-15, with variant bits preserved)
# Byte 8 should have variant bits set but preserve lower 6 bits
expected_byte_8 = (random_bytes[2] & 0x3F) | 0x80
assert result[8] == expected_byte_8
# Bytes 9-15 should match random_bytes[3:10]
assert result[9:16] == random_bytes[3:10]
def test_create_uuidv7_bytes_zero_random():
"""Test with zero random bytes (boundary case)."""
timestamp_ms = 1609459200000
zero_random_bytes = b"\x00" * 10
result = _create_uuidv7_bytes(timestamp_ms, zero_random_bytes)
# Should still be valid UUIDv7
uuid_obj = uuid.UUID(bytes=result)
assert uuid_obj.version == 7
# Version bits should be 0x7000
version_and_rand_a = struct.unpack(">H", result[6:8])[0]
assert version_and_rand_a == 0x7000
# Variant byte should be 0x80 (variant bits + zero random bits)
assert result[8] == 0x80
# Remaining bytes should be zero
assert result[9:16] == b"\x00" * 7
def test_uuidv7_basic_generation():
"""Test basic UUID generation produces valid UUIDv7."""
result = uuidv7()
# Should be a UUID object
assert isinstance(result, uuid.UUID)
# Should be version 7
assert result.version == 7
# Should have correct variant (RFC 4122 variant)
# Variant bits should be 10xxxxxx (0x80-0xBF range)
variant_byte = result.bytes[8]
assert (variant_byte >> 6) == 0b10
def test_uuidv7_with_custom_timestamp():
"""Test UUID generation with custom timestamp."""
custom_timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
result = uuidv7(custom_timestamp)
assert isinstance(result, uuid.UUID)
assert result.version == 7
# Extract and verify timestamp
extracted_timestamp = uuidv7_timestamp(result)
assert isinstance(extracted_timestamp, int)
assert extracted_timestamp == custom_timestamp # Exact match for integer milliseconds
def test_uuidv7_with_none_timestamp(monkeypatch: pytest.MonkeyPatch):
"""Test UUID generation with None timestamp uses current time."""
mock_time = 1609459200
mock_time_func = mock.Mock(return_value=mock_time)
monkeypatch.setattr("time.time", mock_time_func)
result = uuidv7(None)
assert isinstance(result, uuid.UUID)
assert result.version == 7
# Should use the mocked current time (converted to milliseconds)
assert mock_time_func.called
extracted_timestamp = uuidv7_timestamp(result)
assert extracted_timestamp == mock_time * 1000 # 1609459200.0 * 1000
def test_uuidv7_time_ordering():
"""Test that sequential UUIDs have increasing timestamps."""
# Generate UUIDs with incrementing timestamps (in milliseconds)
timestamp1 = 1609459200000 # 2021-01-01 00:00:00 UTC
timestamp2 = 1609459201000 # 2021-01-01 00:00:01 UTC
timestamp3 = 1609459202000 # 2021-01-01 00:00:02 UTC
uuid1 = uuidv7(timestamp1)
uuid2 = uuidv7(timestamp2)
uuid3 = uuidv7(timestamp3)
# Extract timestamps
ts1 = uuidv7_timestamp(uuid1)
ts2 = uuidv7_timestamp(uuid2)
ts3 = uuidv7_timestamp(uuid3)
# Should be in ascending order
assert ts1 < ts2 < ts3
# UUIDs should be lexicographically ordered by their string representation
# due to time-ordering property of UUIDv7
uuid_strings = [str(uuid1), str(uuid2), str(uuid3)]
assert uuid_strings == sorted(uuid_strings)
def test_uuidv7_uniqueness():
"""Test that multiple calls generate different UUIDs."""
# Generate multiple UUIDs with the same timestamp (in milliseconds)
timestamp = 1609459200000
uuids = [uuidv7(timestamp) for _ in range(100)]
# All should be unique despite same timestamp (due to random bits)
assert len(set(uuids)) == 100
# All should have the same extracted timestamp
for uuid_obj in uuids:
extracted_ts = uuidv7_timestamp(uuid_obj)
assert extracted_ts == timestamp
def test_uuidv7_timestamp_error_handling_wrong_version():
"""Test error handling for non-UUIDv7 inputs."""
uuid_v4 = uuid.uuid4()
with pytest.raises(ValueError) as exc_ctx:
uuidv7_timestamp(uuid_v4)
assert "Expected UUIDv7 (version 7)" in str(exc_ctx.value)
assert f"got version {uuid_v4.version}" in str(exc_ctx.value)
@given(st.integers(max_value=2**48 - 1, min_value=0))
def test_uuidv7_timestamp_round_trip(timestamp_ms):
# Generate UUID with timestamp
uuid_obj = uuidv7(timestamp_ms)
# Extract timestamp back
extracted_timestamp = uuidv7_timestamp(uuid_obj)
# Should match exactly for integer millisecond timestamps
assert extracted_timestamp == timestamp_ms
def test_uuidv7_timestamp_edge_cases():
"""Test timestamp extraction with edge case values."""
# Test with very small timestamp
small_timestamp = 1 # 1ms after epoch
uuid_small = uuidv7(small_timestamp)
extracted_small = uuidv7_timestamp(uuid_small)
assert extracted_small == small_timestamp
# Test with large timestamp (year 2038+)
large_timestamp = 2147483647000 # 2038-01-19 03:14:07 UTC in milliseconds
uuid_large = uuidv7(large_timestamp)
extracted_large = uuidv7_timestamp(uuid_large)
assert extracted_large == large_timestamp
def test_uuidv7_boundary_basic_generation():
"""Test basic boundary UUID generation with a known timestamp."""
timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
result = uuidv7_boundary(timestamp)
# Should be a UUID object
assert isinstance(result, uuid.UUID)
# Should be version 7
assert result.version == 7
# Should have correct variant (RFC 4122 variant)
# Variant bits should be 10xxxxxx (0x80-0xBF range)
variant_byte = result.bytes[8]
assert (variant_byte >> 6) == 0b10
def test_uuidv7_boundary_timestamp_extraction():
"""Test that boundary UUID timestamp can be extracted correctly."""
timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
boundary_uuid = uuidv7_boundary(timestamp)
# Extract timestamp using existing function
extracted_timestamp = uuidv7_timestamp(boundary_uuid)
# Should match exactly
assert extracted_timestamp == timestamp
def test_uuidv7_boundary_deterministic():
"""Test that boundary UUIDs are deterministic for same timestamp."""
timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
# Generate multiple boundary UUIDs with same timestamp
uuid1 = uuidv7_boundary(timestamp)
uuid2 = uuidv7_boundary(timestamp)
uuid3 = uuidv7_boundary(timestamp)
# Should all be identical
assert uuid1 == uuid2 == uuid3
assert str(uuid1) == str(uuid2) == str(uuid3)
def test_uuidv7_boundary_is_minimum():
"""Test that boundary UUID is lexicographically smaller than regular UUIDs."""
timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
# Generate boundary UUID
boundary_uuid = uuidv7_boundary(timestamp)
# Generate multiple regular UUIDs with same timestamp
regular_uuids = [uuidv7(timestamp) for _ in range(50)]
# Boundary UUID should be lexicographically smaller than all regular UUIDs
boundary_str = str(boundary_uuid)
for regular_uuid in regular_uuids:
regular_str = str(regular_uuid)
assert boundary_str < regular_str, f"Boundary {boundary_str} should be < regular {regular_str}"
# Also test with bytes comparison
boundary_bytes = boundary_uuid.bytes
for regular_uuid in regular_uuids:
regular_bytes = regular_uuid.bytes
assert boundary_bytes < regular_bytes
def test_uuidv7_boundary_different_timestamps():
"""Test that boundary UUIDs with different timestamps are ordered correctly."""
timestamp1 = 1609459200000 # 2021-01-01 00:00:00 UTC
timestamp2 = 1609459201000 # 2021-01-01 00:00:01 UTC
timestamp3 = 1609459202000 # 2021-01-01 00:00:02 UTC
uuid1 = uuidv7_boundary(timestamp1)
uuid2 = uuidv7_boundary(timestamp2)
uuid3 = uuidv7_boundary(timestamp3)
# Extract timestamps to verify
ts1 = uuidv7_timestamp(uuid1)
ts2 = uuidv7_timestamp(uuid2)
ts3 = uuidv7_timestamp(uuid3)
# Should be in ascending order
assert ts1 < ts2 < ts3
# UUIDs should be lexicographically ordered
uuid_strings = [str(uuid1), str(uuid2), str(uuid3)]
assert uuid_strings == sorted(uuid_strings)
# Bytes should also be ordered
assert uuid1.bytes < uuid2.bytes < uuid3.bytes
def test_uuidv7_boundary_edge_cases():
"""Test boundary UUID generation with edge case timestamp values."""
# Test with timestamp 0 (Unix epoch)
epoch_uuid = uuidv7_boundary(0)
assert isinstance(epoch_uuid, uuid.UUID)
assert epoch_uuid.version == 7
assert uuidv7_timestamp(epoch_uuid) == 0
# Test with very large timestamp values
large_timestamp = 2147483647000 # 2038-01-19 03:14:07 UTC in milliseconds
large_uuid = uuidv7_boundary(large_timestamp)
assert isinstance(large_uuid, uuid.UUID)
assert large_uuid.version == 7
assert uuidv7_timestamp(large_uuid) == large_timestamp
# Test with current time
current_time = int(time.time() * 1000)
current_uuid = uuidv7_boundary(current_time)
assert isinstance(current_uuid, uuid.UUID)
assert current_uuid.version == 7
assert uuidv7_timestamp(current_uuid) == current_time

View File

@@ -0,0 +1,29 @@
import pytest
from yarl import URL
def test_yarl_urls():
expected_1 = "https://dify.ai/api"
assert str(URL("https://dify.ai") / "api") == expected_1
assert str(URL("https://dify.ai/") / "api") == expected_1
expected_2 = "http://dify.ai:12345/api"
assert str(URL("http://dify.ai:12345") / "api") == expected_2
assert str(URL("http://dify.ai:12345/") / "api") == expected_2
expected_3 = "https://dify.ai/api/v1"
assert str(URL("https://dify.ai") / "api" / "v1") == expected_3
assert str(URL("https://dify.ai") / "api/v1") == expected_3
assert str(URL("https://dify.ai/") / "api/v1") == expected_3
assert str(URL("https://dify.ai/api") / "v1") == expected_3
assert str(URL("https://dify.ai/api/") / "v1") == expected_3
expected_4 = "api"
assert str(URL("") / "api") == expected_4
expected_5 = "/api"
assert str(URL("/") / "api") == expected_5
with pytest.raises(ValueError) as e1:
str(URL("https://dify.ai") / "/api")
assert str(e1.value) == "Appending path '/api' starting from slash is forbidden"