dify
This commit is contained in:
1
dify/api/core/moderation/openai_moderation/__builtin__
Normal file
1
dify/api/core/moderation/openai_moderation/__builtin__
Normal file
@@ -0,0 +1 @@
|
||||
1
|
||||
@@ -0,0 +1,60 @@
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
|
||||
|
||||
|
||||
class OpenAIModeration(Moderation):
|
||||
name: str = "openai_moderation"
|
||||
|
||||
@classmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
|
||||
:param tenant_id: the id of workspace
|
||||
:param config: the form config data
|
||||
:return:
|
||||
"""
|
||||
cls._validate_inputs_and_outputs_config(config, True)
|
||||
|
||||
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
if self.config is None:
|
||||
raise ValueError("The config is not set.")
|
||||
|
||||
if self.config["inputs_config"]["enabled"]:
|
||||
preset_response = self.config["inputs_config"]["preset_response"]
|
||||
|
||||
if query:
|
||||
inputs["query__"] = query
|
||||
flagged = self._is_violated(inputs)
|
||||
|
||||
return ModerationInputsResult(
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
if self.config is None:
|
||||
raise ValueError("The config is not set.")
|
||||
|
||||
if self.config["outputs_config"]["enabled"]:
|
||||
flagged = self._is_violated({"text": text})
|
||||
preset_response = self.config["outputs_config"]["preset_response"]
|
||||
|
||||
return ModerationOutputsResult(
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
def _is_violated(self, inputs: dict):
|
||||
text = "\n".join(str(inputs.values()))
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="omni-moderation-latest"
|
||||
)
|
||||
|
||||
openai_moderation = model_instance.invoke_moderation(text=text)
|
||||
|
||||
return openai_moderation
|
||||
Reference in New Issue
Block a user