dify
This commit is contained in:
106
dify/api/services/rag_pipeline/rag_pipeline_task_proxy.py
Normal file
106
dify/api/services/rag_pipeline/rag_pipeline_task_proxy.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import cached_property
|
||||
|
||||
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from services.feature_service import FeatureService
|
||||
from services.file_service import FileService
|
||||
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
|
||||
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RagPipelineTaskProxy:
|
||||
# Default uploaded file name for rag pipeline invoke entities
|
||||
_RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME = "rag_pipeline_invoke_entities.json"
|
||||
|
||||
def __init__(
|
||||
self, dataset_tenant_id: str, user_id: str, rag_pipeline_invoke_entities: Sequence[RagPipelineInvokeEntity]
|
||||
):
|
||||
self._dataset_tenant_id = dataset_tenant_id
|
||||
self._user_id = user_id
|
||||
self._rag_pipeline_invoke_entities = rag_pipeline_invoke_entities
|
||||
self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(dataset_tenant_id, "pipeline")
|
||||
|
||||
@cached_property
|
||||
def features(self):
|
||||
return FeatureService.get_features(self._dataset_tenant_id)
|
||||
|
||||
def _upload_invoke_entities(self) -> str:
|
||||
text = [item.model_dump() for item in self._rag_pipeline_invoke_entities]
|
||||
# Convert list to proper JSON string
|
||||
json_text = json.dumps(text)
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
json_text, self._RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME, self._user_id, self._dataset_tenant_id
|
||||
)
|
||||
return upload_file.id
|
||||
|
||||
def _send_to_direct_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]):
|
||||
logger.info("send file %s to direct queue", upload_file_id)
|
||||
task_func.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=upload_file_id,
|
||||
tenant_id=self._dataset_tenant_id,
|
||||
)
|
||||
|
||||
def _send_to_tenant_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]):
|
||||
logger.info("send file %s to tenant queue", upload_file_id)
|
||||
if self._tenant_isolated_task_queue.get_task_key():
|
||||
# Add to waiting queue using List operations (lpush)
|
||||
self._tenant_isolated_task_queue.push_tasks([upload_file_id])
|
||||
logger.info("push tasks: %s", upload_file_id)
|
||||
else:
|
||||
# Set flag and execute task
|
||||
self._tenant_isolated_task_queue.set_task_waiting_time()
|
||||
task_func.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=upload_file_id,
|
||||
tenant_id=self._dataset_tenant_id,
|
||||
)
|
||||
logger.info("init tasks: %s", upload_file_id)
|
||||
|
||||
def _send_to_default_tenant_queue(self, upload_file_id: str):
|
||||
self._send_to_tenant_queue(upload_file_id, rag_pipeline_run_task)
|
||||
|
||||
def _send_to_priority_tenant_queue(self, upload_file_id: str):
|
||||
self._send_to_tenant_queue(upload_file_id, priority_rag_pipeline_run_task)
|
||||
|
||||
def _send_to_priority_direct_queue(self, upload_file_id: str):
|
||||
self._send_to_direct_queue(upload_file_id, priority_rag_pipeline_run_task)
|
||||
|
||||
def _dispatch(self):
|
||||
upload_file_id = self._upload_invoke_entities()
|
||||
if not upload_file_id:
|
||||
raise ValueError("upload_file_id is empty")
|
||||
|
||||
logger.info(
|
||||
"dispatch args: %s - %s - %s",
|
||||
self._dataset_tenant_id,
|
||||
self.features.billing.enabled,
|
||||
self.features.billing.subscription.plan,
|
||||
)
|
||||
|
||||
# dispatch to different pipeline queue with tenant isolation when billing enabled
|
||||
if self.features.billing.enabled:
|
||||
if self.features.billing.subscription.plan == CloudPlan.SANDBOX:
|
||||
# dispatch to normal pipeline queue with tenant isolation for sandbox plan
|
||||
self._send_to_default_tenant_queue(upload_file_id)
|
||||
else:
|
||||
# dispatch to priority pipeline queue with tenant isolation for other plans
|
||||
self._send_to_priority_tenant_queue(upload_file_id)
|
||||
else:
|
||||
# dispatch to priority pipeline queue without tenant isolation for others, e.g.: self-hosted or enterprise
|
||||
self._send_to_priority_direct_queue(upload_file_id)
|
||||
|
||||
def delay(self):
|
||||
if not self._rag_pipeline_invoke_entities:
|
||||
logger.warning(
|
||||
"Received empty rag pipeline invoke entities, no tasks delivered: %s %s",
|
||||
self._dataset_tenant_id,
|
||||
self._user_id,
|
||||
)
|
||||
return
|
||||
self._dispatch()
|
||||
Reference in New Issue
Block a user