dify
0
dify/api/core/ops/aliyun_trace/__init__.py
Normal file
519
dify/api/core/ops/aliyun_trace/aliyun_trace.py
Normal file
@@ -0,0 +1,519 @@
import logging
from collections.abc import Sequence

from sqlalchemy.orm import sessionmaker

from core.ops.aliyun_trace.data_exporter.traceclient import (
    TraceClient,
    build_endpoint,
    convert_datetime_to_nanoseconds,
    convert_to_span_id,
    convert_to_trace_id,
    generate_span_id,
)
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
from core.ops.aliyun_trace.entities.semconv import (
    GEN_AI_COMPLETION,
    GEN_AI_INPUT_MESSAGE,
    GEN_AI_OUTPUT_MESSAGE,
    GEN_AI_PROMPT,
    GEN_AI_PROVIDER_NAME,
    GEN_AI_REQUEST_MODEL,
    GEN_AI_RESPONSE_FINISH_REASON,
    GEN_AI_USAGE_INPUT_TOKENS,
    GEN_AI_USAGE_OUTPUT_TOKENS,
    GEN_AI_USAGE_TOTAL_TOKENS,
    RETRIEVAL_DOCUMENT,
    RETRIEVAL_QUERY,
    TOOL_DESCRIPTION,
    TOOL_NAME,
    TOOL_PARAMETERS,
    GenAISpanKind,
)
from core.ops.aliyun_trace.utils import (
    create_common_span_attributes,
    create_links_from_trace_id,
    create_status_from_error,
    extract_retrieval_documents,
    format_input_messages,
    format_output_messages,
    format_retrieval_documents,
    get_user_id_from_message_data,
    get_workflow_node_status,
    serialize_json_data,
)
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import AliyunConfig
from core.ops.entities.trace_entity import (
    BaseTraceInfo,
    DatasetRetrievalTraceInfo,
    GenerateNameTraceInfo,
    MessageTraceInfo,
    ModerationTraceInfo,
    SuggestedQuestionTraceInfo,
    ToolTraceInfo,
    WorkflowTraceInfo,
)
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import WorkflowNodeExecutionTriggeredFrom

logger = logging.getLogger(__name__)


class AliyunDataTrace(BaseTraceInstance):
    def __init__(
        self,
        aliyun_config: AliyunConfig,
    ):
        super().__init__(aliyun_config)
        endpoint = build_endpoint(aliyun_config.endpoint, aliyun_config.license_key)
        self.trace_client = TraceClient(service_name=aliyun_config.app_name, endpoint=endpoint)

    def trace(self, trace_info: BaseTraceInfo):
        if isinstance(trace_info, WorkflowTraceInfo):
            self.workflow_trace(trace_info)
        if isinstance(trace_info, MessageTraceInfo):
            self.message_trace(trace_info)
        if isinstance(trace_info, ModerationTraceInfo):
            pass
        if isinstance(trace_info, SuggestedQuestionTraceInfo):
            self.suggested_question_trace(trace_info)
        if isinstance(trace_info, DatasetRetrievalTraceInfo):
            self.dataset_retrieval_trace(trace_info)
        if isinstance(trace_info, ToolTraceInfo):
            self.tool_trace(trace_info)
        if isinstance(trace_info, GenerateNameTraceInfo):
            pass

    def api_check(self):
        return self.trace_client.api_check()

    def get_project_url(self):
        try:
            return self.trace_client.get_project_url()
        except Exception as e:
            logger.info("Aliyun get project url failed: %s", str(e), exc_info=True)
            raise ValueError(f"Aliyun get project url failed: {str(e)}")

    def workflow_trace(self, trace_info: WorkflowTraceInfo):
        trace_metadata = TraceMetadata(
            trace_id=convert_to_trace_id(trace_info.workflow_run_id),
            workflow_span_id=convert_to_span_id(trace_info.workflow_run_id, "workflow"),
            session_id=trace_info.metadata.get("conversation_id") or "",
            user_id=str(trace_info.metadata.get("user_id") or ""),
            links=create_links_from_trace_id(trace_info.trace_id),
        )

        self.add_workflow_span(trace_info, trace_metadata)

        workflow_node_executions = self.get_workflow_node_executions(trace_info)
        for node_execution in workflow_node_executions:
            node_span = self.build_workflow_node_span(node_execution, trace_info, trace_metadata)
            self.trace_client.add_span(node_span)

    def message_trace(self, trace_info: MessageTraceInfo):
        message_data = trace_info.message_data
        if message_data is None:
            return

        message_id = trace_info.message_id
        user_id = get_user_id_from_message_data(message_data)
        status = create_status_from_error(trace_info.error)

        trace_metadata = TraceMetadata(
            trace_id=convert_to_trace_id(message_id),
            workflow_span_id=0,
            session_id=trace_info.metadata.get("conversation_id") or "",
            user_id=user_id,
            links=create_links_from_trace_id(trace_info.trace_id),
        )

        inputs_json = serialize_json_data(trace_info.inputs)
        outputs_str = str(trace_info.outputs)

        message_span_id = convert_to_span_id(message_id, "message")
        message_span = SpanData(
            trace_id=trace_metadata.trace_id,
            parent_span_id=None,
            span_id=message_span_id,
            name="message",
            start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
            end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
            attributes=create_common_span_attributes(
                session_id=trace_metadata.session_id,
                user_id=trace_metadata.user_id,
                span_kind=GenAISpanKind.CHAIN,
                inputs=inputs_json,
                outputs=outputs_str,
            ),
            status=status,
            links=trace_metadata.links,
        )
        self.trace_client.add_span(message_span)

        llm_span = SpanData(
            trace_id=trace_metadata.trace_id,
            parent_span_id=message_span_id,
            span_id=convert_to_span_id(message_id, "llm"),
            name="llm",
            start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
            end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
            attributes={
                **create_common_span_attributes(
                    session_id=trace_metadata.session_id,
                    user_id=trace_metadata.user_id,
                    span_kind=GenAISpanKind.LLM,
                    inputs=inputs_json,
                    outputs=outputs_str,
                ),
                GEN_AI_REQUEST_MODEL: trace_info.metadata.get("ls_model_name") or "",
                GEN_AI_PROVIDER_NAME: trace_info.metadata.get("ls_provider") or "",
                GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens),
                GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens),
                GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens),
                GEN_AI_PROMPT: inputs_json,
                GEN_AI_COMPLETION: outputs_str,
            },
            status=status,
            links=trace_metadata.links,
        )
        self.trace_client.add_span(llm_span)

    def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
        if trace_info.message_data is None:
            return

        message_id = trace_info.message_id

        trace_metadata = TraceMetadata(
            trace_id=convert_to_trace_id(message_id),
            workflow_span_id=0,
            session_id=trace_info.metadata.get("conversation_id") or "",
            user_id=str(trace_info.metadata.get("user_id") or ""),
            links=create_links_from_trace_id(trace_info.trace_id),
        )

        documents_data = extract_retrieval_documents(trace_info.documents)
        documents_json = serialize_json_data(documents_data)
        inputs_str = str(trace_info.inputs)

        dataset_retrieval_span = SpanData(
            trace_id=trace_metadata.trace_id,
            parent_span_id=convert_to_span_id(message_id, "message"),
            span_id=generate_span_id(),
            name="dataset_retrieval",
            start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
            end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
            attributes={
                **create_common_span_attributes(
                    session_id=trace_metadata.session_id,
                    user_id=trace_metadata.user_id,
                    span_kind=GenAISpanKind.RETRIEVER,
                    inputs=inputs_str,
                    outputs=documents_json,
                ),
                RETRIEVAL_QUERY: inputs_str,
                RETRIEVAL_DOCUMENT: documents_json,
            },
            links=trace_metadata.links,
        )
        self.trace_client.add_span(dataset_retrieval_span)

    def tool_trace(self, trace_info: ToolTraceInfo):
        if trace_info.message_data is None:
            return

        message_id = trace_info.message_id
        status = create_status_from_error(trace_info.error)

        trace_metadata = TraceMetadata(
            trace_id=convert_to_trace_id(message_id),
            workflow_span_id=0,
            session_id=trace_info.metadata.get("conversation_id") or "",
            user_id=str(trace_info.metadata.get("user_id") or ""),
            links=create_links_from_trace_id(trace_info.trace_id),
        )

        tool_config_json = serialize_json_data(trace_info.tool_config)
        tool_inputs_json = serialize_json_data(trace_info.tool_inputs)
        inputs_json = serialize_json_data(trace_info.inputs)

        tool_span = SpanData(
            trace_id=trace_metadata.trace_id,
            parent_span_id=convert_to_span_id(message_id, "message"),
            span_id=generate_span_id(),
            name=trace_info.tool_name,
            start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
            end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
            attributes={
                **create_common_span_attributes(
                    session_id=trace_metadata.session_id,
                    user_id=trace_metadata.user_id,
                    span_kind=GenAISpanKind.TOOL,
                    inputs=inputs_json,
                    outputs=str(trace_info.tool_outputs),
                ),
                TOOL_NAME: trace_info.tool_name,
                TOOL_DESCRIPTION: tool_config_json,
                TOOL_PARAMETERS: tool_inputs_json,
            },
            status=status,
            links=trace_metadata.links,
        )
        self.trace_client.add_span(tool_span)

    def get_workflow_node_executions(self, trace_info: WorkflowTraceInfo) -> Sequence[WorkflowNodeExecution]:
        app_id = trace_info.metadata.get("app_id")
        if not app_id:
            raise ValueError("No app_id found in trace_info metadata")

        service_account = self.get_service_account_with_tenant(app_id)

        session_factory = sessionmaker(bind=db.engine)
        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=session_factory,
            user=service_account,
            app_id=app_id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        return workflow_node_execution_repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id)

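    # Node-span dispatch: LLM, knowledge-retrieval and tool nodes get
    # specialized spans; every other node type falls back to a generic TASK
    # span. Build failures are logged and return None, which add_span ignores,
    # so a single bad node cannot abort export of the whole workflow trace.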
    def build_workflow_node_span(
        self, node_execution: WorkflowNodeExecution, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata
    ):
        try:
            if node_execution.node_type == NodeType.LLM:
                node_span = self.build_workflow_llm_span(trace_info, node_execution, trace_metadata)
            elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
                node_span = self.build_workflow_retrieval_span(trace_info, node_execution, trace_metadata)
            elif node_execution.node_type == NodeType.TOOL:
                node_span = self.build_workflow_tool_span(trace_info, node_execution, trace_metadata)
            else:
                node_span = self.build_workflow_task_span(trace_info, node_execution, trace_metadata)
            return node_span
        except Exception as e:
            logger.debug("Error occurred in build_workflow_node_span: %s", e, exc_info=True)
            return None

    def build_workflow_task_span(
        self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata
    ) -> SpanData:
        inputs_json = serialize_json_data(node_execution.inputs)
        outputs_json = serialize_json_data(node_execution.outputs)
        return SpanData(
            trace_id=trace_metadata.trace_id,
            parent_span_id=trace_metadata.workflow_span_id,
            span_id=convert_to_span_id(node_execution.id, "node"),
            name=node_execution.title,
            start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
            end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
            attributes=create_common_span_attributes(
                session_id=trace_metadata.session_id,
                user_id=trace_metadata.user_id,
                span_kind=GenAISpanKind.TASK,
                inputs=inputs_json,
                outputs=outputs_json,
            ),
            status=get_workflow_node_status(node_execution),
            links=trace_metadata.links,
        )

    def build_workflow_tool_span(
        self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata
    ) -> SpanData:
        tool_des = {}
        if node_execution.metadata:
            tool_des = node_execution.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})

        inputs_json = serialize_json_data(node_execution.inputs or {})
        outputs_json = serialize_json_data(node_execution.outputs)

        return SpanData(
            trace_id=trace_metadata.trace_id,
            parent_span_id=trace_metadata.workflow_span_id,
            span_id=convert_to_span_id(node_execution.id, "node"),
            name=node_execution.title,
            start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
            end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
            attributes={
                **create_common_span_attributes(
                    session_id=trace_metadata.session_id,
                    user_id=trace_metadata.user_id,
                    span_kind=GenAISpanKind.TOOL,
                    inputs=inputs_json,
                    outputs=outputs_json,
                ),
                TOOL_NAME: node_execution.title,
                TOOL_DESCRIPTION: serialize_json_data(tool_des),
                TOOL_PARAMETERS: inputs_json,
            },
            status=get_workflow_node_status(node_execution),
            links=trace_metadata.links,
        )

    def build_workflow_retrieval_span(
        self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata
    ) -> SpanData:
        input_value = str(node_execution.inputs.get("query", "")) if node_execution.inputs else ""
        output_value = serialize_json_data(node_execution.outputs.get("result", [])) if node_execution.outputs else ""

        retrieval_documents = node_execution.outputs.get("result", []) if node_execution.outputs else []
        semantic_retrieval_documents = format_retrieval_documents(retrieval_documents)
        semantic_retrieval_documents_json = serialize_json_data(semantic_retrieval_documents)

        return SpanData(
            trace_id=trace_metadata.trace_id,
            parent_span_id=trace_metadata.workflow_span_id,
            span_id=convert_to_span_id(node_execution.id, "node"),
            name=node_execution.title,
            start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
            end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
            attributes={
                **create_common_span_attributes(
                    session_id=trace_metadata.session_id,
                    user_id=trace_metadata.user_id,
                    span_kind=GenAISpanKind.RETRIEVER,
                    inputs=input_value,
                    outputs=output_value,
                ),
                RETRIEVAL_QUERY: input_value,
                RETRIEVAL_DOCUMENT: semantic_retrieval_documents_json,
            },
            status=get_workflow_node_status(node_execution),
            links=trace_metadata.links,
        )

    def build_workflow_llm_span(
        self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata
    ) -> SpanData:
        process_data = node_execution.process_data or {}
        outputs = node_execution.outputs or {}
        usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})

        prompts_json = serialize_json_data(process_data.get("prompts", []))
        text_output = str(outputs.get("text", ""))

        gen_ai_input_message = format_input_messages(process_data)
        gen_ai_output_message = format_output_messages(outputs)

        return SpanData(
            trace_id=trace_metadata.trace_id,
            parent_span_id=trace_metadata.workflow_span_id,
            span_id=convert_to_span_id(node_execution.id, "node"),
            name=node_execution.title,
            start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
            end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
            attributes={
                **create_common_span_attributes(
                    session_id=trace_metadata.session_id,
                    user_id=trace_metadata.user_id,
                    span_kind=GenAISpanKind.LLM,
                    inputs=prompts_json,
                    outputs=text_output,
                ),
                GEN_AI_REQUEST_MODEL: process_data.get("model_name") or "",
                GEN_AI_PROVIDER_NAME: process_data.get("model_provider") or "",
                GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
                GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
                GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
                GEN_AI_PROMPT: prompts_json,
                GEN_AI_COMPLETION: text_output,
                GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason") or "",
                GEN_AI_INPUT_MESSAGE: gen_ai_input_message,
                GEN_AI_OUTPUT_MESSAGE: gen_ai_output_message,
            },
            status=get_workflow_node_status(node_execution),
            links=trace_metadata.links,
        )

    def add_workflow_span(self, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata):
        message_span_id = None
        if trace_info.message_id:
            message_span_id = convert_to_span_id(trace_info.message_id, "message")
        status = create_status_from_error(trace_info.error)

        inputs_json = serialize_json_data(trace_info.workflow_run_inputs)
        outputs_json = serialize_json_data(trace_info.workflow_run_outputs)

        if message_span_id:
            message_span = SpanData(
                trace_id=trace_metadata.trace_id,
                parent_span_id=None,
                span_id=message_span_id,
                name="message",
                start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
                end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
                attributes=create_common_span_attributes(
                    session_id=trace_metadata.session_id,
                    user_id=trace_metadata.user_id,
                    span_kind=GenAISpanKind.CHAIN,
                    inputs=trace_info.workflow_run_inputs.get("sys.query") or "",
                    outputs=outputs_json,
                ),
                status=status,
                links=trace_metadata.links,
            )
            self.trace_client.add_span(message_span)

        workflow_span = SpanData(
            trace_id=trace_metadata.trace_id,
            parent_span_id=message_span_id,
            span_id=trace_metadata.workflow_span_id,
            name="workflow",
            start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
            end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
            attributes=create_common_span_attributes(
                session_id=trace_metadata.session_id,
                user_id=trace_metadata.user_id,
                span_kind=GenAISpanKind.CHAIN,
                inputs=inputs_json,
                outputs=outputs_json,
            ),
            status=status,
            links=trace_metadata.links,
        )
        self.trace_client.add_span(workflow_span)

    def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
        message_id = trace_info.message_id
        status = create_status_from_error(trace_info.error)

        trace_metadata = TraceMetadata(
            trace_id=convert_to_trace_id(message_id),
            workflow_span_id=0,
            session_id=trace_info.metadata.get("conversation_id") or "",
            user_id=str(trace_info.metadata.get("user_id") or ""),
            links=create_links_from_trace_id(trace_info.trace_id),
        )

        inputs_json = serialize_json_data(trace_info.inputs)
        suggested_question_json = serialize_json_data(trace_info.suggested_question)

        suggested_question_span = SpanData(
            trace_id=trace_metadata.trace_id,
            parent_span_id=convert_to_span_id(message_id, "message"),
            span_id=convert_to_span_id(message_id, "suggested_question"),
            name="suggested_question",
            start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
            end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
            attributes={
                **create_common_span_attributes(
                    session_id=trace_metadata.session_id,
                    user_id=trace_metadata.user_id,
                    span_kind=GenAISpanKind.LLM,
                    inputs=inputs_json,
                    outputs=suggested_question_json,
                ),
                GEN_AI_REQUEST_MODEL: trace_info.metadata.get("ls_model_name") or "",
                GEN_AI_PROVIDER_NAME: trace_info.metadata.get("ls_provider") or "",
                GEN_AI_PROMPT: inputs_json,
                GEN_AI_COMPLETION: suggested_question_json,
            },
            status=status,
            links=trace_metadata.links,
        )
        self.trace_client.add_span(suggested_question_span)
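
A minimal wiring sketch for the class above (a sketch, not part of the diff: it assumes AliyunConfig accepts at least the app_name, license_key and endpoint fields read in __init__, and the trace-info value is hypothetical):

    from core.ops.entities.config_entity import AliyunConfig

    config = AliyunConfig(
        app_name="my-dify-app",  # illustrative values
        license_key="lk-xxxx",
        endpoint="https://example.log.aliyuncs.com",
    )
    tracer = AliyunDataTrace(aliyun_config=config)
    # trace() dispatches on the concrete TraceInfo subclass; ModerationTraceInfo
    # and GenerateNameTraceInfo are currently accepted but not exported.
    tracer.trace(workflow_trace_info)  # hypothetical WorkflowTraceInfo instance
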
236
dify/api/core/ops/aliyun_trace/data_exporter/traceclient.py
Normal file
@@ -0,0 +1,236 @@
import hashlib
import logging
import random
import socket
import threading
import uuid
from collections import deque
from collections.abc import Sequence
from datetime import datetime
from typing import Final, cast
from urllib.parse import urljoin

import httpx
from opentelemetry import trace as trace_api
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.util.instrumentation import InstrumentationScope
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.trace import Link, SpanContext, TraceFlags

from configs import dify_config
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData

INVALID_SPAN_ID: Final[int] = 0x0000000000000000
INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000
DEFAULT_TIMEOUT: Final[int] = 5
DEFAULT_MAX_QUEUE_SIZE: Final[int] = 1000
DEFAULT_SCHEDULE_DELAY_SEC: Final[int] = 5
DEFAULT_MAX_EXPORT_BATCH_SIZE: Final[int] = 50

logger = logging.getLogger(__name__)


class TraceClient:
    def __init__(
        self,
        service_name: str,
        endpoint: str,
        max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE,
        schedule_delay_sec: int = DEFAULT_SCHEDULE_DELAY_SEC,
        max_export_batch_size: int = DEFAULT_MAX_EXPORT_BATCH_SIZE,
    ):
        self.endpoint = endpoint
        self.resource = Resource(
            attributes={
                ResourceAttributes.SERVICE_NAME: service_name,
                ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
                ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
                ResourceAttributes.HOST_NAME: socket.gethostname(),
            }
        )
        self.span_builder = SpanBuilder(self.resource)
        self.exporter = OTLPSpanExporter(endpoint=endpoint)

        self.max_queue_size = max_queue_size
        self.schedule_delay_sec = schedule_delay_sec
        self.max_export_batch_size = max_export_batch_size

        self.queue: deque = deque(maxlen=max_queue_size)
        self.condition = threading.Condition(threading.Lock())
        self.done = False

        self.worker_thread = threading.Thread(target=self._worker, daemon=True)
        self.worker_thread.start()

        self._spans_dropped = False

    def export(self, spans: Sequence[ReadableSpan]):
        self.exporter.export(spans)

    def api_check(self) -> bool:
        try:
            response = httpx.head(self.endpoint, timeout=DEFAULT_TIMEOUT)
            if response.status_code == 405:
                return True
            else:
                logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
                return False
        except httpx.RequestError as e:
            logger.debug("AliyunTrace API check failed: %s", str(e))
            raise ValueError(f"AliyunTrace API check failed: {str(e)}")

    def get_project_url(self) -> str:
        return "https://arms.console.aliyun.com/#/llm"

    def add_span(self, span_data: SpanData | None) -> None:
        if span_data is None:
            return

        span: ReadableSpan = self.span_builder.build_span(span_data)
        with self.condition:
            if len(self.queue) == self.max_queue_size:
                if not self._spans_dropped:
                    logger.warning("Queue is full, likely spans will be dropped.")
                    self._spans_dropped = True

            self.queue.appendleft(span)
            if len(self.queue) >= self.max_export_batch_size:
                self.condition.notify()

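    # Producer/consumer batching: add_span enqueues under the condition lock
    # and notifies once a full batch is buffered; the worker below also wakes
    # every schedule_delay_sec seconds so partially filled batches still flush.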
    def _worker(self) -> None:
        while not self.done:
            with self.condition:
                if len(self.queue) < self.max_export_batch_size and not self.done:
                    self.condition.wait(timeout=self.schedule_delay_sec)
            self._export_batch()

    def _export_batch(self) -> None:
        spans_to_export: list[ReadableSpan] = []
        with self.condition:
            while len(spans_to_export) < self.max_export_batch_size and self.queue:
                spans_to_export.append(self.queue.pop())

        if spans_to_export:
            try:
                self.exporter.export(spans_to_export)
            except Exception as e:
                logger.debug("Error exporting spans: %s", e)

    def shutdown(self) -> None:
        with self.condition:
            self.done = True
            self.condition.notify_all()
        self.worker_thread.join()
        self._export_batch()
        self.exporter.shutdown()


class SpanBuilder:
    def __init__(self, resource: Resource) -> None:
        self.resource = resource
        self.instrumentation_scope = InstrumentationScope(
            __name__,
            "",
            None,
            None,
        )

    def build_span(self, span_data: SpanData) -> ReadableSpan:
        span_context = trace_api.SpanContext(
            trace_id=span_data.trace_id,
            span_id=span_data.span_id,
            is_remote=False,
            trace_flags=trace_api.TraceFlags(trace_api.TraceFlags.SAMPLED),
            trace_state=None,
        )

        parent_span_context = None
        if span_data.parent_span_id is not None:
            parent_span_context = trace_api.SpanContext(
                trace_id=span_data.trace_id,
                span_id=span_data.parent_span_id,
                is_remote=False,
                trace_flags=trace_api.TraceFlags(trace_api.TraceFlags.SAMPLED),
                trace_state=None,
            )

        span = ReadableSpan(
            name=span_data.name,
            context=span_context,
            parent=parent_span_context,
            resource=self.resource,
            attributes=span_data.attributes,
            events=span_data.events,
            links=span_data.links,
            kind=trace_api.SpanKind.INTERNAL,
            status=span_data.status,
            start_time=span_data.start_time,
            end_time=span_data.end_time,
            instrumentation_scope=self.instrumentation_scope,
        )
        return span


def create_link(trace_id_str: str) -> Link:
    placeholder_span_id = INVALID_SPAN_ID
    try:
        trace_id = int(trace_id_str, 16)
    except ValueError as e:
        raise ValueError(f"Invalid trace ID format: {trace_id_str}") from e

    span_context = SpanContext(
        trace_id=trace_id, span_id=placeholder_span_id, is_remote=False, trace_flags=TraceFlags(TraceFlags.SAMPLED)
    )

    return Link(span_context)


def generate_span_id() -> int:
    span_id = random.getrandbits(64)
    while span_id == INVALID_SPAN_ID:
        span_id = random.getrandbits(64)
    return span_id


def convert_to_trace_id(uuid_v4: str | None) -> int:
    if uuid_v4 is None:
        raise ValueError("UUID cannot be None")
    try:
        uuid_obj = uuid.UUID(uuid_v4)
        return cast(int, uuid_obj.int)
    except ValueError as e:
        raise ValueError(f"Invalid UUID input: {uuid_v4}") from e


def convert_string_to_id(string: str | None) -> int:
    if not string:
        return generate_span_id()
    hash_bytes = hashlib.sha256(string.encode("utf-8")).digest()
    return int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)


def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int:
    if uuid_v4 is None:
        raise ValueError("UUID cannot be None")
    try:
        uuid_obj = uuid.UUID(uuid_v4)
    except ValueError as e:
        raise ValueError(f"Invalid UUID input: {uuid_v4}") from e
    combined_key = f"{uuid_obj.hex}-{span_type}"
    return convert_string_to_id(combined_key)


def convert_datetime_to_nanoseconds(start_time_a: datetime | None) -> int | None:
    if start_time_a is None:
        return None
    timestamp_in_seconds = start_time_a.timestamp()
    return int(timestamp_in_seconds * 1e9)


def build_endpoint(base_url: str, license_key: str) -> str:
    if "log.aliyuncs.com" in base_url:  # cms2.0 endpoint
        return urljoin(base_url, f"adapt_{license_key}/api/v1/traces")
    else:  # xtrace endpoint
        return urljoin(base_url, f"adapt_{license_key}/api/otlp/traces")
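
A small behavioral sketch, runnable against this module: IDs are derived deterministically, so every method that hashes the same (uuid, span_type) pair agrees on the resulting span ID without sharing state:

    from core.ops.aliyun_trace.data_exporter.traceclient import (
        convert_to_span_id,
        convert_to_trace_id,
    )

    run_id = "123e4567-e89b-42d3-a456-426614174000"  # illustrative UUID
    # Same (uuid, span_type) pair -> same 64-bit ID, in any process.
    assert convert_to_span_id(run_id, "workflow") == convert_to_span_id(run_id, "workflow")
    # A different span_type maps to an independent ID for the same run.
    assert convert_to_span_id(run_id, "workflow") != convert_to_span_id(run_id, "message")
    # Trace IDs reuse the UUID's 128-bit integer value directly.
    assert convert_to_trace_id(run_id) == int(run_id.replace("-", ""), 16)
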
0
dify/api/core/ops/aliyun_trace/entities/__init__.py
Normal file
36
dify/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py
Normal file
@@ -0,0 +1,36 @@
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any

from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
from pydantic import BaseModel, Field


@dataclass
class TraceMetadata:
    """Metadata for trace operations, containing common attributes for all spans in a trace."""

    trace_id: int
    workflow_span_id: int
    session_id: str
    user_id: str
    links: list[trace_api.Link]


class SpanData(BaseModel):
    """Data model for span information in Aliyun trace system."""

    model_config = {"arbitrary_types_allowed": True}

    trace_id: int = Field(..., description="The unique identifier for the trace.")
    parent_span_id: int | None = Field(None, description="The ID of the parent span, if any.")
    span_id: int = Field(..., description="The unique identifier for this span.")
    name: str = Field(..., description="The name of the span.")
    attributes: dict[str, Any] = Field(default_factory=dict, description="Attributes associated with the span.")
    events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.")
    links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.")
    status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
    start_time: int | None = Field(..., description="The start time of the span in nanoseconds.")
    end_time: int | None = Field(..., description="The end time of the span in nanoseconds.")
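
A construction sketch for SpanData (illustrative values; attributes, events, links and status all have defaults, with status falling back to UNSET):

    from opentelemetry.trace import Status, StatusCode

    span = SpanData(
        trace_id=0x0123456789ABCDEF0123456789ABCDEF,
        parent_span_id=None,
        span_id=0x0123456789ABCDEF,
        name="message",
        start_time=1_700_000_000_000_000_000,  # nanoseconds since the epoch
        end_time=1_700_000_000_500_000_000,
        status=Status(StatusCode.OK),
    )
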
46
dify/api/core/ops/aliyun_trace/entities/semconv.py
Normal file
@@ -0,0 +1,46 @@
from enum import StrEnum
from typing import Final

# Public attributes
GEN_AI_SESSION_ID: Final[str] = "gen_ai.session.id"
GEN_AI_USER_ID: Final[str] = "gen_ai.user.id"
GEN_AI_USER_NAME: Final[str] = "gen_ai.user.name"
GEN_AI_SPAN_KIND: Final[str] = "gen_ai.span.kind"
GEN_AI_FRAMEWORK: Final[str] = "gen_ai.framework"

# Chain attributes
INPUT_VALUE: Final[str] = "input.value"
OUTPUT_VALUE: Final[str] = "output.value"

# Retriever attributes
RETRIEVAL_QUERY: Final[str] = "retrieval.query"
RETRIEVAL_DOCUMENT: Final[str] = "retrieval.document"

# LLM attributes
GEN_AI_REQUEST_MODEL: Final[str] = "gen_ai.request.model"
GEN_AI_PROVIDER_NAME: Final[str] = "gen_ai.provider.name"
GEN_AI_USAGE_INPUT_TOKENS: Final[str] = "gen_ai.usage.input_tokens"
GEN_AI_USAGE_OUTPUT_TOKENS: Final[str] = "gen_ai.usage.output_tokens"
GEN_AI_USAGE_TOTAL_TOKENS: Final[str] = "gen_ai.usage.total_tokens"
GEN_AI_PROMPT: Final[str] = "gen_ai.prompt"
GEN_AI_COMPLETION: Final[str] = "gen_ai.completion"
GEN_AI_RESPONSE_FINISH_REASON: Final[str] = "gen_ai.response.finish_reason"

GEN_AI_INPUT_MESSAGE: Final[str] = "gen_ai.input.messages"
GEN_AI_OUTPUT_MESSAGE: Final[str] = "gen_ai.output.messages"

# Tool attributes
TOOL_NAME: Final[str] = "tool.name"
TOOL_DESCRIPTION: Final[str] = "tool.description"
TOOL_PARAMETERS: Final[str] = "tool.parameters"


class GenAISpanKind(StrEnum):
    CHAIN = "CHAIN"
    RETRIEVER = "RETRIEVER"
    RERANKER = "RERANKER"
    LLM = "LLM"
    EMBEDDING = "EMBEDDING"
    TOOL = "TOOL"
    AGENT = "AGENT"
    TASK = "TASK"
190
dify/api/core/ops/aliyun_trace/utils.py
Normal file
@@ -0,0 +1,190 @@
import json
from collections.abc import Mapping
from typing import Any

from opentelemetry.trace import Link, Status, StatusCode

from core.ops.aliyun_trace.entities.semconv import (
    GEN_AI_FRAMEWORK,
    GEN_AI_SESSION_ID,
    GEN_AI_SPAN_KIND,
    GEN_AI_USER_ID,
    INPUT_VALUE,
    OUTPUT_VALUE,
    GenAISpanKind,
)
from core.rag.models.document import Document
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import WorkflowNodeExecutionStatus
from extensions.ext_database import db
from models import EndUser

# Constants
DEFAULT_JSON_ENSURE_ASCII = False
DEFAULT_FRAMEWORK_NAME = "dify"


def get_user_id_from_message_data(message_data) -> str:
    user_id = message_data.from_account_id
    if message_data.from_end_user_id:
        end_user_data: EndUser | None = (
            db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
        )
        if end_user_data is not None:
            user_id = end_user_data.session_id
    return user_id


def create_status_from_error(error: str | None) -> Status:
    if error:
        return Status(StatusCode.ERROR, error)
    return Status(StatusCode.OK)


def get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status:
    if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED:
        return Status(StatusCode.OK)
    if node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]:
        return Status(StatusCode.ERROR, str(node_execution.error))
    return Status(StatusCode.UNSET)


def create_links_from_trace_id(trace_id: str | None) -> list[Link]:
    from core.ops.aliyun_trace.data_exporter.traceclient import create_link

    links = []
    if trace_id:
        links.append(create_link(trace_id_str=trace_id))
    return links


def extract_retrieval_documents(documents: list[Document]) -> list[dict[str, Any]]:
    documents_data = []
    for document in documents:
        document_data = {
            "content": document.page_content,
            "metadata": {
                "dataset_id": document.metadata.get("dataset_id"),
                "doc_id": document.metadata.get("doc_id"),
                "document_id": document.metadata.get("document_id"),
            },
            "score": document.metadata.get("score"),
        }
        documents_data.append(document_data)
    return documents_data


def serialize_json_data(data: Any, ensure_ascii: bool = DEFAULT_JSON_ENSURE_ASCII) -> str:
    return json.dumps(data, ensure_ascii=ensure_ascii)


def create_common_span_attributes(
    session_id: str = "",
    user_id: str = "",
    span_kind: str = GenAISpanKind.CHAIN,
    framework: str = DEFAULT_FRAMEWORK_NAME,
    inputs: str = "",
    outputs: str = "",
) -> dict[str, Any]:
    return {
        GEN_AI_SESSION_ID: session_id,
        GEN_AI_USER_ID: user_id,
        GEN_AI_SPAN_KIND: span_kind,
        GEN_AI_FRAMEWORK: framework,
        INPUT_VALUE: inputs,
        OUTPUT_VALUE: outputs,
    }


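# The format_* helpers below defensively normalize Dify's internal prompt and
# result payloads into the gen_ai.* message/document shapes; on malformed
# input they return an empty JSON list rather than raising.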
def format_retrieval_documents(retrieval_documents: list) -> list:
    try:
        if not isinstance(retrieval_documents, list):
            return []

        semantic_documents = []
        for doc in retrieval_documents:
            if not isinstance(doc, dict):
                continue

            metadata = doc.get("metadata", {})
            content = doc.get("content", "")
            title = doc.get("title", "")
            score = metadata.get("score", 0.0)
            document_id = metadata.get("document_id", "")

            semantic_metadata = {}
            if title:
                semantic_metadata["title"] = title
            if metadata.get("source"):
                semantic_metadata["source"] = metadata["source"]
            elif metadata.get("_source"):
                semantic_metadata["source"] = metadata["_source"]
            if metadata.get("doc_metadata"):
                doc_metadata = metadata["doc_metadata"]
                if isinstance(doc_metadata, dict):
                    semantic_metadata.update(doc_metadata)

            semantic_doc = {
                "document": {"content": content, "metadata": semantic_metadata, "score": score, "id": document_id}
            }
            semantic_documents.append(semantic_doc)

        return semantic_documents
    except Exception:
        return []


def format_input_messages(process_data: Mapping[str, Any]) -> str:
    try:
        if not isinstance(process_data, dict):
            return serialize_json_data([])

        prompts = process_data.get("prompts", [])
        if not prompts:
            return serialize_json_data([])

        valid_roles = {"system", "user", "assistant", "tool"}
        input_messages = []
        for prompt in prompts:
            if not isinstance(prompt, dict):
                continue

            role = prompt.get("role", "")
            text = prompt.get("text", "")

            if not role or role not in valid_roles:
                continue

            if text:
                message = {"role": role, "parts": [{"type": "text", "content": text}]}
                input_messages.append(message)

        return serialize_json_data(input_messages)
    except Exception:
        return serialize_json_data([])


def format_output_messages(outputs: Mapping[str, Any]) -> str:
    try:
        if not isinstance(outputs, dict):
            return serialize_json_data([])

        text = outputs.get("text", "")
        finish_reason = outputs.get("finish_reason", "")

        if not text:
            return serialize_json_data([])

        valid_finish_reasons = {"stop", "length", "content_filter", "tool_call", "error"}
        if finish_reason not in valid_finish_reasons:
            finish_reason = "stop"

        output_message = {
            "role": "assistant",
            "parts": [{"type": "text", "content": text}],
            "finish_reason": finish_reason,
        }

        return serialize_json_data([output_message])
    except Exception:
        return serialize_json_data([])
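
A quick illustration of the normalization, runnable against this module (the payload shapes mirror the process_data/outputs dicts that build_workflow_llm_span passes in):

    from core.ops.aliyun_trace.utils import format_input_messages, format_output_messages

    process_data = {"prompts": [{"role": "user", "text": "Hi"}, {"role": "narrator", "text": "dropped"}]}
    # Unknown roles are filtered out; valid prompts become role + text parts.
    print(format_input_messages(process_data))
    # -> [{"role": "user", "parts": [{"type": "text", "content": "Hi"}]}]

    # An unrecognized finish_reason is coerced to "stop".
    print(format_output_messages({"text": "Hello!", "finish_reason": "eos"}))
    # -> [{"role": "assistant", "parts": [{"type": "text", "content": "Hello!"}], "finish_reason": "stop"}]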