83 lines
2.6 KiB
Python
83 lines
2.6 KiB
Python
|
|
import json
|
||
|
|
from collections.abc import Sequence
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
from pydantic import BaseModel, ValidationError
|
||
|
|
|
||
|
|
from extensions.ext_redis import redis_client
|
||
|
|
|
||
|
|
_DEFAULT_TASK_TTL = 60 * 60 # 1 hour
|
||
|
|
|
||
|
|
|
||
|
|
class TaskWrapper(BaseModel):
    """Pydantic envelope that carries an arbitrary task payload as JSON.

    The payload lives in the single ``data`` field, so the encoded form is a
    JSON object of the shape ``{"data": ...}``.
    """

    data: Any  # the task payload; must be JSON-serializable by pydantic

    def serialize(self) -> str:
        """Return this wrapper encoded as a JSON string."""
        json_text = self.model_dump_json()
        return json_text

    @classmethod
    def deserialize(cls, serialized_data: str) -> "TaskWrapper":
        """Rebuild a wrapper from JSON text such as the output of ``serialize``.

        Pydantic raises ``ValidationError`` if the text does not validate
        against this model.
        """
        instance = cls.model_validate_json(serialized_data)
        return instance
|
||
|
|
|
||
|
|
|
||
|
|
class TenantIsolatedTaskQueue:
    """
    Simple queue for tenant-isolated tasks, used to isolate RAG-related work per tenant.

    Tasks are stored in a Redis list (pushed with LPUSH, popped with RPOP, so they
    come out in the order they were pushed). A separate Redis key holds a
    "task waiting" flag that expires after a TTL.

    Supports tasks that are plain strings (stored verbatim) or values that can be
    serialized to JSON (stored wrapped in ``TaskWrapper``).
    """

    def __init__(self, tenant_id: str, unique_key: str):
        """
        :param tenant_id: tenant whose tasks this queue isolates.
        :param unique_key: caller-chosen segment embedded in both Redis key names,
            keeping different queues for the same tenant apart.
        """
        self._tenant_id = tenant_id
        self._unique_key = unique_key
        # These key formats address data already stored in Redis; changing them
        # (including the asymmetric "tenant_self_" prefix) would orphan that data.
        self._queue = f"tenant_self_{unique_key}_task_queue:{tenant_id}"
        self._task_key = f"tenant_{unique_key}_task:{tenant_id}"

    def get_task_key(self):
        """Return the raw Redis value of the task-waiting flag (presumably ``None`` when unset or expired)."""
        return redis_client.get(self._task_key)

    def set_task_waiting_time(self, ttl: int = _DEFAULT_TASK_TTL):
        """Set the task-waiting flag to ``1`` with an expiry of ``ttl`` seconds (SETEX)."""
        redis_client.setex(self._task_key, ttl, 1)

    def delete_task_key(self):
        """Delete the task-waiting flag."""
        redis_client.delete(self._task_key)

    def push_tasks(self, tasks: Sequence[Any]):
        """
        Append ``tasks`` to this tenant's queue with a single LPUSH.

        Strings are stored verbatim; every other task is JSON-encoded through
        ``TaskWrapper``. An empty ``tasks`` sequence is a no-op, because LPUSH
        requires at least one value.
        """
        serialized_tasks = [self._serialize_task(task) for task in tasks]
        if not serialized_tasks:
            return
        redis_client.lpush(self._queue, *serialized_tasks)

    def pull_tasks(self, count: int = 1) -> Sequence[Any]:
        """
        Pop up to ``count`` tasks, oldest first.

        Stops early once the queue is empty. Each popped value is decoded from
        UTF-8 if Redis returned bytes, then unwrapped from ``TaskWrapper``; values
        that are not a valid wrapper (e.g. raw string tasks) are returned as strings.

        :param count: maximum number of tasks to pop; ``<= 0`` returns ``[]``.
        :return: list of task payloads, possibly shorter than ``count``.
        """
        if count <= 0:
            return []

        tasks = []
        for _ in range(count):
            serialized_task = redis_client.rpop(self._queue)
            # RPOP replies nil (None) only when the list is empty. Compare with
            # None rather than truthiness: an empty-string task is a legitimate
            # entry that RPOP has already removed, so breaking on it would drop
            # it silently and stop draining the rest of the queue.
            if serialized_task is None:
                break
            tasks.append(self._deserialize_task(serialized_task))
        return tasks

    @staticmethod
    def _serialize_task(task: Any) -> str:
        """Encode one task for storage: strings verbatim, anything else via TaskWrapper."""
        # Store str tasks directly, maintaining full compatibility for pipeline scenarios.
        if isinstance(task, str):
            return task
        return TaskWrapper(data=task).serialize()

    @staticmethod
    def _deserialize_task(serialized_task: Any) -> Any:
        """Decode one stored task, falling back to the raw string when it is not a TaskWrapper."""
        if isinstance(serialized_task, bytes):
            serialized_task = serialized_task.decode("utf-8")
        try:
            return TaskWrapper.deserialize(serialized_task).data
        except (json.JSONDecodeError, ValidationError, TypeError, ValueError):
            # Fall back to the raw string for legacy-format tasks or invalid JSON.
            # NOTE(review): a raw string task that happens to be JSON shaped like
            # {"data": ...} cannot be told apart from a wrapped task and will be
            # unwrapped here.
            return serialized_task
|