dify
This commit is contained in:
0
dify/api/core/moderation/__init__.py
Normal file
0
dify/api/core/moderation/__init__.py
Normal file
1
dify/api/core/moderation/api/__builtin__
Normal file
1
dify/api/core/moderation/api/__builtin__
Normal file
@@ -0,0 +1 @@
|
||||
3
|
||||
0
dify/api/core/moderation/api/__init__.py
Normal file
0
dify/api/core/moderation/api/__init__.py
Normal file
94
dify/api/core/moderation/api/api.py
Normal file
94
dify/api/core/moderation/api/api.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor
|
||||
from core.helper.encrypter import decrypt_token
|
||||
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
|
||||
from extensions.ext_database import db
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
|
||||
|
||||
class ModerationInputParams(BaseModel):
    """Payload handed to an API-based extension when moderating user inputs."""

    # Id of the app the moderated request belongs to.
    app_id: str = Field(default="")
    # Form inputs supplied by the user; each instance gets its own dict.
    inputs: dict = Field(default_factory=dict)
    # Free-text query that accompanies the inputs (may be empty).
    query: str = Field(default="")
|
||||
|
||||
|
||||
class ModerationOutputParams(BaseModel):
    """Payload handed to an API-based extension when moderating generated text."""

    # Id of the app the moderated response belongs to.
    app_id: str = Field(default="")
    # Text to be reviewed; required, there is no default.
    text: str
|
||||
|
||||
|
||||
class ApiModeration(Moderation):
    """Moderation backend that forwards the decision to an API-based extension."""

    name: str = "api"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict):
        """
        Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :raises ValueError: if the inputs/outputs sections are invalid, the
            extension id is missing, or no matching extension exists
        :return:
        """
        # Preset responses are not required for this backend (flag is False).
        cls._validate_inputs_and_outputs_config(config, False)

        extension_id = config.get("api_based_extension_id")
        if not extension_id:
            raise ValueError("api_based_extension_id is required")

        if not cls._get_api_based_extension(tenant_id, extension_id):
            raise ValueError("API-based Extension not found. Please check it again.")

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        """
        Ask the configured extension to review the user inputs.

        :param inputs: user inputs
        :param query: query string
        :raises ValueError: if no config is set or the extension cannot be found
        :return: the extension's verdict, or an unflagged result when input
            moderation is disabled in the config
        """
        if self.config is None:
            raise ValueError("The config is not set.")

        if not self.config["inputs_config"]["enabled"]:
            return ModerationInputsResult(
                flagged=False, action=ModerationAction.DIRECT_OUTPUT, preset_response=""
            )

        payload = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query).model_dump()
        response = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, payload)
        return ModerationInputsResult.model_validate(response)

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        """
        Ask the configured extension to review generated text.

        :param text: LLM output content
        :raises ValueError: if no config is set or the extension cannot be found
        :return: the extension's verdict, or an unflagged result when output
            moderation is disabled in the config
        """
        if self.config is None:
            raise ValueError("The config is not set.")

        if not self.config["outputs_config"]["enabled"]:
            return ModerationOutputsResult(
                flagged=False, action=ModerationAction.DIRECT_OUTPUT, preset_response=""
            )

        payload = ModerationOutputParams(app_id=self.app_id, text=text).model_dump()
        response = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, payload)
        return ModerationOutputsResult.model_validate(response)

    def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict):
        """
        Send ``params`` to the configured extension at ``extension_point``.

        :param extension_point: which extension hook to invoke
        :param params: request payload
        :raises ValueError: if no config is set or the extension cannot be found
        :return: whatever the requestor returns for the call
        """
        if self.config is None:
            raise ValueError("The config is not set.")

        extension_id = self.config.get("api_based_extension_id", "")
        extension = self._get_api_based_extension(self.tenant_id, extension_id)
        if not extension:
            raise ValueError("API-based Extension not found. Please check it again.")

        # The stored key is passed through decrypt_token before being handed to the requestor.
        endpoint = extension.api_endpoint
        api_key = decrypt_token(self.tenant_id, extension.api_key)
        return APIBasedExtensionRequestor(endpoint, api_key).request(extension_point, params)

    @staticmethod
    def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension | None:
        """
        Look up an API-based extension by tenant id and extension id.

        :return: the result of the scalar query (``None`` when nothing matches)
        """
        query = select(APIBasedExtension).where(
            APIBasedExtension.tenant_id == tenant_id,
            APIBasedExtension.id == api_based_extension_id,
        )
        return db.session.scalar(query)
|
||||
114
dify/api/core/moderation/base.py
Normal file
114
dify/api/core/moderation/base.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.extension.extensible import Extensible, ExtensionModule
|
||||
|
||||
|
||||
class ModerationAction(StrEnum):
|
||||
DIRECT_OUTPUT = auto()
|
||||
OVERRIDDEN = auto()
|
||||
|
||||
|
||||
class ModerationInputsResult(BaseModel):
    """Outcome of moderating user inputs."""

    # True when the moderated content was found to violate the policy.
    flagged: bool = Field(default=False)
    # How flagged content is to be handled; required, no default.
    action: ModerationAction
    # Canned reply associated with the DIRECT_OUTPUT action.
    preset_response: str = Field(default="")
    # Replacement inputs associated with the OVERRIDDEN action.
    inputs: dict = Field(default_factory=dict)
    # Replacement query associated with the OVERRIDDEN action.
    query: str = Field(default="")
|
||||
|
||||
|
||||
class ModerationOutputsResult(BaseModel):
    """Outcome of moderating generated output text."""

    # True when the moderated text was found to violate the policy.
    flagged: bool = Field(default=False)
    # How flagged text is to be handled; required, no default.
    action: ModerationAction
    # Canned reply associated with the DIRECT_OUTPUT action.
    preset_response: str = Field(default="")
    # Replacement text associated with the OVERRIDDEN action.
    text: str = Field(default="")
|
||||
|
||||
|
||||
class Moderation(Extensible, ABC):
    """
    Abstract base for every moderation backend.
    """

    module: ExtensionModule = ExtensionModule.MODERATION

    def __init__(self, app_id: str, tenant_id: str, config: dict | None = None):
        """
        :param app_id: id of the app this moderation instance serves
        :param tenant_id: the id of workspace
        :param config: backend-specific config, may be omitted
        """
        super().__init__(tenant_id, config)
        self.app_id = app_id

    @classmethod
    @abstractmethod
    def validate_config(cls, tenant_id: str, config: dict):
        """
        Check the form config submitted for this backend.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        """
        Review user inputs.

        Called after the user has submitted inputs; performs the sensitive
        content review on them and returns the processed result.

        :param inputs: user inputs
        :param query: query string (required in chat app)
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        """
        Review LLM output.

        The output content (possibly a segment of it) is passed to this method
        for sensitive content review; content that fails the review is shielded.

        :param text: LLM output content
        :return:
        """
        raise NotImplementedError

    @classmethod
    def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool):
        """
        Validate the ``inputs_config`` and ``outputs_config`` sections of ``config``.

        :param config: the form config data
        :param is_preset_response_required: when True, every enabled section
            must carry a non-empty ``preset_response`` of at most 100 characters
        :raises ValueError: on the first rule that is violated
        """
        # Both sections must be dicts; inputs_config is checked first.
        sections: dict = {}
        for section_name in ("inputs_config", "outputs_config"):
            section = config.get(section_name)
            if not isinstance(section, dict):
                raise ValueError(f"{section_name} must be a dict")
            sections[section_name] = section

        enabled = {section_name: section.get("enabled") for section_name, section in sections.items()}
        if not any(enabled.values()):
            raise ValueError("At least one of inputs_config or outputs_config must be enabled")

        if not is_preset_response_required:
            return

        # Only enabled sections need a preset response.
        for section_name, section in sections.items():
            if not enabled[section_name]:
                continue

            preset_response = section.get("preset_response")
            if not preset_response:
                raise ValueError(f"{section_name}.preset_response is required")

            if len(preset_response) > 100:
                raise ValueError(f"{section_name}.preset_response must be less than 100 characters")
|
||||
|
||||
|
||||
class ModerationError(Exception):
    """Raised when moderation rejects content; the message carries the reply to show."""
|
||||
48
dify/api/core/moderation/factory.py
Normal file
48
dify/api/core/moderation/factory.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from core.extension.extensible import ExtensionModule
|
||||
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult
|
||||
from extensions.ext_code_based_extension import code_based_extension
|
||||
|
||||
|
||||
class ModerationFactory:
    """Builds the moderation backend registered under a name and proxies calls to it."""

    __extension_instance: Moderation

    def __init__(self, name: str, app_id: str, tenant_id: str, config: dict):
        """
        :param name: the name of extension
        :param app_id: id of the app the backend will serve
        :param tenant_id: the id of workspace
        :param config: backend-specific config
        """
        moderation_cls = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
        self.__extension_instance = moderation_cls(app_id, tenant_id, config)

    @classmethod
    def validate_config(cls, name: str, tenant_id: str, config: dict):
        """
        Check the form config against the backend registered under ``name``.

        :param name: the name of extension
        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        """
        moderation_cls = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
        # FIXME: mypy error, try to fix it instead of using type: ignore
        moderation_cls.validate_config(tenant_id, config)  # type: ignore

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        """
        Review user inputs with the wrapped backend.

        Called after the user has submitted inputs; performs the sensitive
        content review on them and returns the processed result.

        :param inputs: user inputs
        :param query: query string (required in chat app)
        :return:
        """
        backend = self.__extension_instance
        return backend.moderation_for_inputs(inputs, query)

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        """
        Review LLM output with the wrapped backend.

        The output content (possibly a segment of it) is passed to this method
        for sensitive content review; content that fails the review is shielded.

        :param text: LLM output content
        :return:
        """
        backend = self.__extension_instance
        return backend.moderation_for_outputs(text)
|
||||
71
dify/api/core/moderation/input_moderation.py
Normal file
71
dify/api/core/moderation/input_moderation.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import AppConfig
|
||||
from core.moderation.base import ModerationAction, ModerationError
|
||||
from core.moderation.factory import ModerationFactory
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.utils import measure_time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InputModeration:
    """Runs the app's configured input moderation before a request is processed."""

    def check(
        self,
        app_id: str,
        tenant_id: str,
        app_config: AppConfig,
        inputs: Mapping[str, Any],
        query: str,
        message_id: str,
        trace_manager: TraceQueueManager | None = None,
    ) -> tuple[bool, Mapping[str, Any], str]:
        """
        Process sensitive_word_avoidance.

        :param app_id: app id
        :param tenant_id: tenant id
        :param app_config: app config
        :param inputs: inputs
        :param query: query
        :param message_id: message id
        :param trace_manager: trace manager
        :raises ModerationError: when content is flagged with the DIRECT_OUTPUT
            action; the message is the preset response
        :return: ``(flagged, inputs, query)``; inputs and query are the
            moderation result's values when the action is OVERRIDDEN
        """
        # Work on a shallow copy so the caller's mapping is not the object passed on.
        inputs = dict(inputs)

        avoidance = app_config.sensitive_word_avoidance
        if not avoidance:
            return False, inputs, query

        factory = ModerationFactory(
            name=avoidance.type, app_id=app_id, tenant_id=tenant_id, config=avoidance.config
        )

        with measure_time() as timer:
            moderation_result = factory.moderation_for_inputs(inputs, query)

        if trace_manager:
            trace_task = TraceTask(
                TraceTaskName.MODERATION_TRACE,
                message_id=message_id,
                moderation_result=moderation_result,
                inputs=inputs,
                timer=timer,
            )
            trace_manager.add_trace_task(trace_task)

        if not moderation_result.flagged:
            return False, inputs, query

        action = moderation_result.action
        if action == ModerationAction.DIRECT_OUTPUT:
            raise ModerationError(moderation_result.preset_response)

        if action == ModerationAction.OVERRIDDEN:
            inputs = moderation_result.inputs
            query = moderation_result.query

        return True, inputs, query
|
||||
1
dify/api/core/moderation/keywords/__builtin__
Normal file
1
dify/api/core/moderation/keywords/__builtin__
Normal file
@@ -0,0 +1 @@
|
||||
2
|
||||
0
dify/api/core/moderation/keywords/__init__.py
Normal file
0
dify/api/core/moderation/keywords/__init__.py
Normal file
73
dify/api/core/moderation/keywords/keywords.py
Normal file
73
dify/api/core/moderation/keywords/keywords.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
|
||||
|
||||
|
||||
class KeywordsModeration(Moderation):
    """Moderation backend that flags content containing any configured keyword."""

    name: str = "keywords"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict):
        """
        Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :raises ValueError: if the inputs/outputs sections are invalid, or the
            keywords are missing, not a string, longer than 10000 characters,
            or span more than 100 rows
        :return:
        """
        # Preset responses are required for this backend (flag is True).
        cls._validate_inputs_and_outputs_config(config, True)

        keywords = config.get("keywords")
        if not keywords:
            raise ValueError("keywords is required")

        # The rows are obtained with str.split below; reject other types with a
        # validation error instead of letting an AttributeError escape.
        if not isinstance(keywords, str):
            raise ValueError("keywords must be a string")

        if len(keywords) > 10000:
            raise ValueError("keywords length must be less than 10000")

        if len(keywords.split("\n")) > 100:
            raise ValueError("the number of rows for the keywords must be less than 100")

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        """
        Flag the inputs (and query) when any value contains a configured keyword.

        :param inputs: user inputs; not modified
        :param query: query string, checked together with the inputs when non-empty
        :raises ValueError: if no config is set
        :return: result with the DIRECT_OUTPUT action and the configured preset
            response; unflagged with an empty response when input moderation
            is disabled
        """
        flagged = False
        preset_response = ""
        if self.config is None:
            raise ValueError("The config is not set.")

        if self.config["inputs_config"]["enabled"]:
            preset_response = self.config["inputs_config"]["preset_response"]

            # Check a copy so the synthetic "query__" entry never leaks into
            # the caller's dict.
            values_to_check = dict(inputs)
            if query:
                values_to_check["query__"] = query

            flagged = self._is_violated(values_to_check, self._keywords_list())

        return ModerationInputsResult(
            flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
        )

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        """
        Flag the text when it contains a configured keyword.

        :param text: LLM output content
        :raises ValueError: if no config is set
        :return: result with the DIRECT_OUTPUT action and the configured preset
            response; unflagged with an empty response when output moderation
            is disabled
        """
        flagged = False
        preset_response = ""
        if self.config is None:
            raise ValueError("The config is not set.")

        if self.config["outputs_config"]["enabled"]:
            flagged = self._is_violated({"text": text}, self._keywords_list())
            preset_response = self.config["outputs_config"]["preset_response"]

        return ModerationOutputsResult(
            flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
        )

    def _keywords_list(self) -> list[str]:
        """Return the configured keywords, one per line, with empty rows dropped."""
        if self.config is None:
            raise ValueError("The config is not set.")
        return [keyword for keyword in self.config["keywords"].split("\n") if keyword]

    def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
        """Return True when any value of ``inputs`` contains any keyword."""
        return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())

    def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool:
        """Return True when ``str(value)`` contains any keyword, ignoring case."""
        # Lower-case the value once rather than once per keyword.
        haystack = str(value).lower()
        return any(keyword.lower() in haystack for keyword in keywords_list)
|
||||
1
dify/api/core/moderation/openai_moderation/__builtin__
Normal file
1
dify/api/core/moderation/openai_moderation/__builtin__
Normal file
@@ -0,0 +1 @@
|
||||
1
|
||||
@@ -0,0 +1,60 @@
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
|
||||
|
||||
|
||||
class OpenAIModeration(Moderation):
    """Moderation backend that asks the tenant's OpenAI moderation model for a verdict."""

    name: str = "openai_moderation"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict):
        """
        Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :raises ValueError: if the inputs/outputs sections are invalid
        :return:
        """
        # Preset responses are required for this backend (flag is True).
        cls._validate_inputs_and_outputs_config(config, True)

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        """
        Submit the inputs (and query) to the moderation model.

        :param inputs: user inputs; not modified
        :param query: query string, checked together with the inputs when non-empty
        :raises ValueError: if no config is set
        :return: result with the DIRECT_OUTPUT action and the configured preset
            response; unflagged with an empty response when input moderation
            is disabled
        """
        flagged = False
        preset_response = ""
        if self.config is None:
            raise ValueError("The config is not set.")

        if self.config["inputs_config"]["enabled"]:
            preset_response = self.config["inputs_config"]["preset_response"]

            # Check a copy so the synthetic "query__" entry never leaks into
            # the caller's dict.
            values_to_check = dict(inputs)
            if query:
                values_to_check["query__"] = query
            flagged = self._is_violated(values_to_check)

        return ModerationInputsResult(
            flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
        )

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        """
        Submit generated text to the moderation model.

        :param text: LLM output content
        :raises ValueError: if no config is set
        :return: result with the DIRECT_OUTPUT action and the configured preset
            response; unflagged with an empty response when output moderation
            is disabled
        """
        flagged = False
        preset_response = ""
        if self.config is None:
            raise ValueError("The config is not set.")

        if self.config["outputs_config"]["enabled"]:
            flagged = self._is_violated({"text": text})
            preset_response = self.config["outputs_config"]["preset_response"]

        return ModerationOutputsResult(
            flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
        )

    def _is_violated(self, inputs: dict):
        """
        Send every value of ``inputs`` to the moderation model.

        :param inputs: values to check; each is converted with ``str``
        :return: whatever ``invoke_moderation`` returns for the combined text
        """
        # One value per line. The previous code joined str(inputs.values()),
        # which put a newline between every character of the dict_values repr.
        text = "\n".join(str(value) for value in inputs.values())

        model_instance = ModelManager().get_model_instance(
            tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="omni-moderation-latest"
        )

        return model_instance.invoke_moderation(text=text)
|
||||
141
dify/api/core/moderation/output_moderation.py
Normal file
141
dify/api/core/moderation/output_moderation.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import QueueMessageReplaceEvent
|
||||
from core.moderation.base import ModerationAction, ModerationOutputsResult
|
||||
from core.moderation.factory import ModerationFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModerationRule(BaseModel):
    """Which moderation backend to use for output moderation, and its config."""

    # Name the backend is looked up by when the moderation factory is built.
    type: str
    # Backend-specific config passed through to the factory unchanged.
    config: dict[str, Any]
|
||||
|
||||
|
||||
class OutputModeration(BaseModel):
    """
    Moderates LLM output, both while it streams and once it is complete.

    Tokens are accumulated in ``buffer``; a background thread started on the
    first token re-checks the buffer as it grows and publishes a replace event
    when the content is flagged.
    """

    tenant_id: str
    app_id: str

    rule: ModerationRule
    queue_manager: AppQueueManager

    # Background worker; created lazily by append_new_token.
    thread: threading.Thread | None = None
    # Cleared by stop_thread to ask the worker loop to exit.
    thread_running: bool = True
    # Output text accumulated so far.
    buffer: str = ""
    # Set by moderation_completion once the whole completion is known.
    is_final_chunk: bool = False
    # Preset response chosen by the worker on a DIRECT_OUTPUT hit.
    final_output: str | None = None
    model_config = ConfigDict(arbitrary_types_allowed=True)

    def should_direct_output(self) -> bool:
        """Return True once the worker has stored a preset response."""
        return self.final_output is not None

    def get_final_output(self) -> str:
        """Return the stored preset response, or an empty string if none is set."""
        return self.final_output or ""

    def append_new_token(self, token: str):
        """Append ``token`` to the buffer and start the worker thread if none exists yet."""
        self.buffer += token

        if not self.thread:
            self.thread = self.start_thread()

    def moderation_completion(self, completion: str, public_event: bool = False) -> tuple[str, bool]:
        """
        Moderate the complete output synchronously.

        :param completion: the full LLM output
        :param public_event: when True, publish a replace event if the output is flagged
        :return: ``(text, flagged)``; ``text`` is ``completion`` when not flagged
            (or when moderation failed), the preset response for DIRECT_OUTPUT,
            otherwise the text returned by the moderation result
        """
        self.buffer = completion
        self.is_final_chunk = True

        result = self.moderation(tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=completion)

        if not result or not result.flagged:
            return completion, False

        if result.action == ModerationAction.DIRECT_OUTPUT:
            final_output = result.preset_response
        else:
            final_output = result.text

        if public_event:
            self._publish_replacement(final_output)

        return final_output, True

    def start_thread(self) -> threading.Thread:
        """Start and return the background worker thread."""
        # The previous fallback re-read the same config value it had just
        # rejected, so a non-positive size reached the worker unchanged and
        # made it moderate on every loop iteration without sleeping.
        # NOTE(review): 300 is believed to be the default of
        # MODERATION_BUFFER_SIZE - confirm against the config definition.
        fallback_buffer_size = 300
        configured_size = dify_config.MODERATION_BUFFER_SIZE
        buffer_size = configured_size if configured_size > 0 else fallback_buffer_size

        thread = threading.Thread(
            target=self.worker,
            kwargs={
                "flask_app": current_app._get_current_object(),  # type: ignore
                "buffer_size": buffer_size,
            },
        )

        thread.start()

        return thread

    def stop_thread(self):
        """Ask a live worker thread to exit its loop."""
        if self.thread and self.thread.is_alive():
            self.thread_running = False

    def worker(self, flask_app: Flask, buffer_size: int):
        """
        Background loop that moderates the buffer as it grows.

        While streaming, the buffer is moderated each time it has grown by at
        least ``buffer_size`` characters since the last check. Once the final
        chunk is set, the final buffer is moderated once. The loop ends on a
        DIRECT_OUTPUT hit or when ``thread_running`` is cleared.

        :param flask_app: application whose context the loop runs in
        :param buffer_size: minimum growth, in characters, between two checks
        """
        with flask_app.app_context():
            current_length = 0
            # Final buffer that has already been moderated, if any.
            checked_final_buffer: str | None = None
            while self.thread_running:
                moderation_buffer = self.buffer
                buffer_length = len(moderation_buffer)
                if not self.is_final_chunk:
                    chunk_length = buffer_length - current_length
                    if 0 <= chunk_length < buffer_size:
                        time.sleep(1)
                        continue
                elif moderation_buffer == checked_final_buffer:
                    # The final buffer was already moderated; idle until
                    # stopped instead of re-checking the same text in a
                    # tight loop.
                    time.sleep(1)
                    continue
                else:
                    checked_final_buffer = moderation_buffer

                current_length = buffer_length

                result = self.moderation(
                    tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=moderation_buffer
                )

                if not result or not result.flagged:
                    continue

                if result.action == ModerationAction.DIRECT_OUTPUT:
                    final_output = result.preset_response
                    self.final_output = final_output
                else:
                    # Keep whatever was appended to the buffer while the
                    # moderation call was in flight.
                    final_output = result.text + self.buffer[len(moderation_buffer) :]

                # trigger replace event
                if self.thread_running:
                    self._publish_replacement(final_output)

                if result.action == ModerationAction.DIRECT_OUTPUT:
                    break

    def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> ModerationOutputsResult | None:
        """
        Run output moderation on ``moderation_buffer``.

        :return: the moderation result, or ``None`` when any exception was
            raised (the exception is logged, not propagated)
        """
        try:
            moderation_factory = ModerationFactory(
                name=self.rule.type, app_id=app_id, tenant_id=tenant_id, config=self.rule.config
            )

            result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
            return result
        except Exception:
            # Best effort: a failing backend must not break the response stream.
            logger.exception("Moderation Output error, app_id: %s", app_id)

        return None

    def _publish_replacement(self, text: str) -> None:
        """Publish a message-replace event carrying ``text`` to the queue manager."""
        self.queue_manager.publish(
            QueueMessageReplaceEvent(
                text=text, reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION
            ),
            PublishFrom.TASK_PIPELINE,
        )
|
||||
Reference in New Issue
Block a user