107 lines
4.6 KiB
Python
107 lines
4.6 KiB
Python
|
|
import json
|
||
|
|
import logging
|
||
|
|
from collections.abc import Callable, Sequence
|
||
|
|
from functools import cached_property
|
||
|
|
|
||
|
|
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
|
||
|
|
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||
|
|
from enums.cloud_plan import CloudPlan
|
||
|
|
from extensions.ext_database import db
|
||
|
|
from services.feature_service import FeatureService
|
||
|
|
from services.file_service import FileService
|
||
|
|
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
|
||
|
|
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
|
||
|
|
|
||
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class RagPipelineTaskProxy:
    """Hand a batch of RAG pipeline invoke entities over to a background run task.

    The entities are serialized to JSON and uploaded as a text file; only the
    resulting file id and the tenant id are passed to the task. Which task is
    used, and whether the tenant-isolated queue gates it, depends on the
    tenant's billing features (see ``_dispatch``).
    """

    # Default uploaded file name for rag pipeline invoke entities
    _RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME = "rag_pipeline_invoke_entities.json"

    def __init__(
        self, dataset_tenant_id: str, user_id: str, rag_pipeline_invoke_entities: Sequence[RagPipelineInvokeEntity]
    ):
        """Remember the tenant, the user and the entities, and build the tenant queue.

        Args:
            dataset_tenant_id: Tenant id used for feature lookup, file upload,
                the tenant-isolated queue and the ``tenant_id`` task argument.
            user_id: User id passed to the file upload.
            rag_pipeline_invoke_entities: Entities to serialize and deliver.
        """
        self._rag_pipeline_invoke_entities = rag_pipeline_invoke_entities
        self._user_id = user_id
        self._dataset_tenant_id = dataset_tenant_id
        # The queue is keyed by the tenant id and the fixed "pipeline" name.
        self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(dataset_tenant_id, "pipeline")

    @cached_property
    def features(self):
        """Features of the dataset tenant; looked up once and cached on the instance."""
        tenant_features = FeatureService.get_features(self._dataset_tenant_id)
        return tenant_features

    def _upload_invoke_entities(self) -> str:
        """Serialize the invoke entities to JSON, upload them, and return the file id."""
        # Each entity is dumped via ``model_dump()`` and the list is encoded as one JSON document.
        payload = json.dumps([entity.model_dump() for entity in self._rag_pipeline_invoke_entities])
        file_service = FileService(db.engine)
        uploaded = file_service.upload_text(
            payload,
            self._RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME,
            self._user_id,
            self._dataset_tenant_id,
        )
        return uploaded.id

    def _delay_task(self, upload_file_id: str, task_func: Callable[[str, str], None]):
        """Call ``task_func.delay`` with the uploaded file id and this proxy's tenant id."""
        task_func.delay(  # type: ignore
            rag_pipeline_invoke_entities_file_id=upload_file_id,
            tenant_id=self._dataset_tenant_id,
        )

    def _send_to_direct_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]):
        """Submit ``task_func`` right away, without consulting the tenant-isolated queue."""
        logger.info("send file %s to direct queue", upload_file_id)
        self._delay_task(upload_file_id, task_func)

    def _send_to_tenant_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]):
        """Submit ``task_func`` through the tenant-isolated queue.

        When the queue reports a task key, the file id is only pushed onto the
        queue; otherwise the waiting time is set and the task is submitted.
        """
        logger.info("send file %s to tenant queue", upload_file_id)
        tenant_queue = self._tenant_isolated_task_queue

        if tenant_queue.get_task_key():
            # Add to waiting queue using List operations (lpush)
            # NOTE(review): presumably a truthy key means a task for this tenant is already in flight - confirm.
            tenant_queue.push_tasks([upload_file_id])
            logger.info("push tasks: %s", upload_file_id)
            return

        # Set flag and execute task
        tenant_queue.set_task_waiting_time()
        self._delay_task(upload_file_id, task_func)
        logger.info("init tasks: %s", upload_file_id)

    def _send_to_default_tenant_queue(self, upload_file_id: str):
        """Tenant-isolated delivery using ``rag_pipeline_run_task``."""
        self._send_to_tenant_queue(upload_file_id, rag_pipeline_run_task)

    def _send_to_priority_tenant_queue(self, upload_file_id: str):
        """Tenant-isolated delivery using ``priority_rag_pipeline_run_task``."""
        self._send_to_tenant_queue(upload_file_id, priority_rag_pipeline_run_task)

    def _send_to_priority_direct_queue(self, upload_file_id: str):
        """Direct delivery using ``priority_rag_pipeline_run_task``."""
        self._send_to_direct_queue(upload_file_id, priority_rag_pipeline_run_task)

    def _dispatch(self):
        """Upload the entities and route the resulting file id to the matching queue.

        Raises:
            ValueError: If the upload yields an empty file id.
        """
        upload_file_id = self._upload_invoke_entities()
        if not upload_file_id:
            raise ValueError("upload_file_id is empty")

        billing = self.features.billing
        billing_enabled = billing.enabled
        # The plan is read even when billing is disabled because it is always logged.
        plan = billing.subscription.plan
        logger.info("dispatch args: %s - %s - %s", self._dataset_tenant_id, billing_enabled, plan)

        if not billing_enabled:
            # dispatch to priority pipeline queue without tenant isolation for others, e.g.: self-hosted or enterprise
            self._send_to_priority_direct_queue(upload_file_id)
            return

        # dispatch to different pipeline queue with tenant isolation when billing enabled
        if plan == CloudPlan.SANDBOX:
            # dispatch to normal pipeline queue with tenant isolation for sandbox plan
            self._send_to_default_tenant_queue(upload_file_id)
        else:
            # dispatch to priority pipeline queue with tenant isolation for other plans
            self._send_to_priority_tenant_queue(upload_file_id)

    def delay(self):
        """Deliver the invoke entities, or log a warning and do nothing when there are none."""
        if self._rag_pipeline_invoke_entities:
            self._dispatch()
            return

        logger.warning(
            "Received empty rag pipeline invoke entities, no tasks delivered: %s %s",
            self._dataset_tenant_id,
            self._user_id,
        )
|