This commit is contained in:
2025-12-01 17:21:38 +08:00
parent 32fee2b8ab
commit fab8c13cb3
7511 changed files with 996300 additions and 0 deletions

View File

View File

@@ -0,0 +1,92 @@
import logging
from sqlalchemy import select
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.dataset_service import DatasetCollectionBindingService
logger = logging.getLogger(__name__)
class AnnotationReplyFeature:
def query(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> MessageAnnotation | None:
"""
Query app annotations to reply
:param app_record: app record
:param message: message
:param query: query
:param user_id: user id
:param invoke_from: invoke from
:return:
"""
stmt = select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id)
annotation_setting = db.session.scalar(stmt)
if not annotation_setting:
return None
collection_binding_detail = annotation_setting.collection_binding_detail
if not collection_binding_detail:
return None
try:
score_threshold = annotation_setting.score_threshold or 1
embedding_provider_name = collection_binding_detail.provider_name
embedding_model_name = collection_binding_detail.model_name
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name, embedding_model_name, "annotation"
)
dataset = Dataset(
id=app_record.id,
tenant_id=app_record.tenant_id,
indexing_technique="high_quality",
embedding_model_provider=embedding_provider_name,
embedding_model=embedding_model_name,
collection_binding_id=dataset_collection_binding.id,
)
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
documents = vector.search_by_vector(
query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]}
)
if documents and documents[0].metadata:
annotation_id = documents[0].metadata["annotation_id"]
score = documents[0].metadata["score"]
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
if annotation:
if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}:
from_source = "api"
else:
from_source = "console"
# insert annotation history
AppAnnotationService.add_annotation_history(
annotation.id,
app_record.id,
annotation.question,
annotation.content,
query,
user_id,
message.id,
from_source,
score,
)
return annotation
except Exception as e:
logger.warning("Query annotation failed, exception: %s.", str(e))
return None
return None

View File

@@ -0,0 +1,31 @@
import logging
from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity
from core.helper import moderation
from core.model_runtime.entities.message_entities import PromptMessage
logger = logging.getLogger(__name__)
class HostingModerationFeature:
def check(
self, application_generate_entity: EasyUIBasedAppGenerateEntity, prompt_messages: list[PromptMessage]
) -> bool:
"""
Check hosting moderation
:param application_generate_entity: application generate entity
:param prompt_messages: prompt messages
:return:
"""
model_config = application_generate_entity.model_conf
text = ""
for prompt_message in prompt_messages:
if isinstance(prompt_message.content, str):
text += prompt_message.content + "\n"
moderation_result = moderation.check_moderation(
tenant_id=application_generate_entity.app_config.tenant_id, model_config=model_config, text=text
)
return moderation_result

View File

@@ -0,0 +1,3 @@
from .rate_limit import RateLimit
__all__ = ["RateLimit"]

View File

@@ -0,0 +1,130 @@
import logging
import time
import uuid
from collections.abc import Generator, Mapping
from datetime import timedelta
from typing import Any, Union
from core.errors.error import AppInvokeQuotaExceededError
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
class RateLimit:
_MAX_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:max_active_requests"
_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:active_requests"
_UNLIMITED_REQUEST_ID = "unlimited_request_id"
_REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
_instance_dict: dict[str, "RateLimit"] = {}
def __new__(cls, client_id: str, max_active_requests: int):
if client_id not in cls._instance_dict:
instance = super().__new__(cls)
cls._instance_dict[client_id] = instance
return cls._instance_dict[client_id]
def __init__(self, client_id: str, max_active_requests: int):
self.max_active_requests = max_active_requests
# must be called after max_active_requests is set
if self.disabled():
return
if hasattr(self, "initialized"):
return
self.initialized = True
self.client_id = client_id
self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id)
self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id)
self.last_recalculate_time = float("-inf")
self.flush_cache(use_local_value=True)
def flush_cache(self, use_local_value=False):
if self.disabled():
return
self.last_recalculate_time = time.time()
# flush max active requests
if use_local_value or not redis_client.exists(self.max_active_requests_key):
redis_client.setex(self.max_active_requests_key, timedelta(days=1), self.max_active_requests)
else:
self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8"))
redis_client.expire(self.max_active_requests_key, timedelta(days=1))
# flush max active requests (in-transit request list)
if not redis_client.exists(self.active_requests_key):
return
request_details = redis_client.hgetall(self.active_requests_key)
redis_client.expire(self.active_requests_key, timedelta(days=1))
timeout_requests = [
k
for k, v in request_details.items()
if time.time() - float(v.decode("utf-8")) > RateLimit._REQUEST_MAX_ALIVE_TIME
]
if timeout_requests:
redis_client.hdel(self.active_requests_key, *timeout_requests)
def enter(self, request_id: str | None = None) -> str:
if self.disabled():
return RateLimit._UNLIMITED_REQUEST_ID
if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL:
self.flush_cache()
if not request_id:
request_id = RateLimit.gen_request_key()
active_requests_count = redis_client.hlen(self.active_requests_key)
if active_requests_count >= self.max_active_requests:
raise AppInvokeQuotaExceededError(
f"Too many requests. Please try again later. The current maximum concurrent requests allowed "
f"for {self.client_id} is {self.max_active_requests}."
)
redis_client.hset(self.active_requests_key, request_id, str(time.time()))
return request_id
def exit(self, request_id: str):
if request_id == RateLimit._UNLIMITED_REQUEST_ID:
return
redis_client.hdel(self.active_requests_key, request_id)
def disabled(self):
return self.max_active_requests <= 0
@staticmethod
def gen_request_key() -> str:
return str(uuid.uuid4())
def generate(self, generator: Union[Generator[str, None, None], Mapping[str, Any]], request_id: str):
if isinstance(generator, Mapping):
return generator
else:
return RateLimitGenerator(
rate_limit=self,
generator=generator,
request_id=request_id,
)
class RateLimitGenerator:
def __init__(self, rate_limit: RateLimit, generator: Generator[str, None, None], request_id: str):
self.rate_limit = rate_limit
self.generator = generator
self.request_id = request_id
self.closed = False
def __iter__(self):
return self
def __next__(self):
if self.closed:
raise StopIteration
try:
return next(self.generator)
except Exception:
self.close()
raise
def close(self):
if not self.closed:
self.closed = True
self.rate_limit.exit(self.request_id)
if self.generator is not None and hasattr(self.generator, "close"):
self.generator.close()