commit fab8c13cb3
parent 32fee2b8ab
2025-12-01 17:21:38 +08:00
7511 changed files with 996300 additions and 0 deletions

core/ops/aliyun_trace/aliyun_trace.py

@@ -0,0 +1,519 @@
import logging
from collections.abc import Sequence
from sqlalchemy.orm import sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
TraceClient,
build_endpoint,
convert_datetime_to_nanoseconds,
convert_to_span_id,
convert_to_trace_id,
generate_span_id,
)
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
from core.ops.aliyun_trace.entities.semconv import (
GEN_AI_COMPLETION,
GEN_AI_INPUT_MESSAGE,
GEN_AI_OUTPUT_MESSAGE,
GEN_AI_PROMPT,
GEN_AI_PROVIDER_NAME,
GEN_AI_REQUEST_MODEL,
GEN_AI_RESPONSE_FINISH_REASON,
GEN_AI_USAGE_INPUT_TOKENS,
GEN_AI_USAGE_OUTPUT_TOKENS,
GEN_AI_USAGE_TOTAL_TOKENS,
RETRIEVAL_DOCUMENT,
RETRIEVAL_QUERY,
TOOL_DESCRIPTION,
TOOL_NAME,
TOOL_PARAMETERS,
GenAISpanKind,
)
from core.ops.aliyun_trace.utils import (
create_common_span_attributes,
create_links_from_trace_id,
create_status_from_error,
extract_retrieval_documents,
format_input_messages,
format_output_messages,
format_retrieval_documents,
get_user_id_from_message_data,
get_workflow_node_status,
serialize_json_data,
)
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import AliyunConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
class AliyunDataTrace(BaseTraceInstance):
def __init__(
self,
aliyun_config: AliyunConfig,
):
super().__init__(aliyun_config)
endpoint = build_endpoint(aliyun_config.endpoint, aliyun_config.license_key)
self.trace_client = TraceClient(service_name=aliyun_config.app_name, endpoint=endpoint)
def trace(self, trace_info: BaseTraceInfo):
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo):
pass
if isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo):
pass
def api_check(self):
return self.trace_client.api_check()
def get_project_url(self):
try:
return self.trace_client.get_project_url()
except Exception as e:
logger.info("Aliyun get project url failed: %s", str(e), exc_info=True)
raise ValueError(f"Aliyun get project url failed: {str(e)}")
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_metadata = TraceMetadata(
trace_id=convert_to_trace_id(trace_info.workflow_run_id),
workflow_span_id=convert_to_span_id(trace_info.workflow_run_id, "workflow"),
session_id=trace_info.metadata.get("conversation_id") or "",
user_id=str(trace_info.metadata.get("user_id") or ""),
links=create_links_from_trace_id(trace_info.trace_id),
)
self.add_workflow_span(trace_info, trace_metadata)
workflow_node_executions = self.get_workflow_node_executions(trace_info)
for node_execution in workflow_node_executions:
node_span = self.build_workflow_node_span(node_execution, trace_info, trace_metadata)
self.trace_client.add_span(node_span)
def message_trace(self, trace_info: MessageTraceInfo):
message_data = trace_info.message_data
if message_data is None:
return
message_id = trace_info.message_id
user_id = get_user_id_from_message_data(message_data)
status = create_status_from_error(trace_info.error)
trace_metadata = TraceMetadata(
trace_id=convert_to_trace_id(message_id),
workflow_span_id=0,
session_id=trace_info.metadata.get("conversation_id") or "",
user_id=user_id,
links=create_links_from_trace_id(trace_info.trace_id),
)
inputs_json = serialize_json_data(trace_info.inputs)
outputs_str = str(trace_info.outputs)
message_span_id = convert_to_span_id(message_id, "message")
message_span = SpanData(
trace_id=trace_metadata.trace_id,
parent_span_id=None,
span_id=message_span_id,
name="message",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes=create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.CHAIN,
inputs=inputs_json,
outputs=outputs_str,
),
status=status,
links=trace_metadata.links,
)
self.trace_client.add_span(message_span)
llm_span = SpanData(
trace_id=trace_metadata.trace_id,
parent_span_id=message_span_id,
span_id=convert_to_span_id(message_id, "llm"),
name="llm",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
**create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.LLM,
inputs=inputs_json,
outputs=outputs_str,
),
GEN_AI_REQUEST_MODEL: trace_info.metadata.get("ls_model_name") or "",
GEN_AI_PROVIDER_NAME: trace_info.metadata.get("ls_provider") or "",
GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens),
GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens),
GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens),
GEN_AI_PROMPT: inputs_json,
GEN_AI_COMPLETION: outputs_str,
},
status=status,
links=trace_metadata.links,
)
self.trace_client.add_span(llm_span)
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
return
message_id = trace_info.message_id
trace_metadata = TraceMetadata(
trace_id=convert_to_trace_id(message_id),
workflow_span_id=0,
session_id=trace_info.metadata.get("conversation_id") or "",
user_id=str(trace_info.metadata.get("user_id") or ""),
links=create_links_from_trace_id(trace_info.trace_id),
)
documents_data = extract_retrieval_documents(trace_info.documents)
documents_json = serialize_json_data(documents_data)
inputs_str = str(trace_info.inputs)
dataset_retrieval_span = SpanData(
trace_id=trace_metadata.trace_id,
parent_span_id=convert_to_span_id(message_id, "message"),
span_id=generate_span_id(),
name="dataset_retrieval",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
**create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.RETRIEVER,
inputs=inputs_str,
outputs=documents_json,
),
RETRIEVAL_QUERY: inputs_str,
RETRIEVAL_DOCUMENT: documents_json,
},
links=trace_metadata.links,
)
self.trace_client.add_span(dataset_retrieval_span)
def tool_trace(self, trace_info: ToolTraceInfo):
if trace_info.message_data is None:
return
message_id = trace_info.message_id
status = create_status_from_error(trace_info.error)
trace_metadata = TraceMetadata(
trace_id=convert_to_trace_id(message_id),
workflow_span_id=0,
session_id=trace_info.metadata.get("conversation_id") or "",
user_id=str(trace_info.metadata.get("user_id") or ""),
links=create_links_from_trace_id(trace_info.trace_id),
)
tool_config_json = serialize_json_data(trace_info.tool_config)
tool_inputs_json = serialize_json_data(trace_info.tool_inputs)
inputs_json = serialize_json_data(trace_info.inputs)
tool_span = SpanData(
trace_id=trace_metadata.trace_id,
parent_span_id=convert_to_span_id(message_id, "message"),
span_id=generate_span_id(),
name=trace_info.tool_name,
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
**create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.TOOL,
inputs=inputs_json,
outputs=str(trace_info.tool_outputs),
),
TOOL_NAME: trace_info.tool_name,
TOOL_DESCRIPTION: tool_config_json,
TOOL_PARAMETERS: tool_inputs_json,
},
status=status,
links=trace_metadata.links,
)
self.trace_client.add_span(tool_span)
def get_workflow_node_executions(self, trace_info: WorkflowTraceInfo) -> Sequence[WorkflowNodeExecution]:
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
service_account = self.get_service_account_with_tenant(app_id)
session_factory = sessionmaker(bind=db.engine)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=service_account,
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
return workflow_node_execution_repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id)
def build_workflow_node_span(
self, node_execution: WorkflowNodeExecution, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata
):
try:
if node_execution.node_type == NodeType.LLM:
node_span = self.build_workflow_llm_span(trace_info, node_execution, trace_metadata)
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
node_span = self.build_workflow_retrieval_span(trace_info, node_execution, trace_metadata)
elif node_execution.node_type == NodeType.TOOL:
node_span = self.build_workflow_tool_span(trace_info, node_execution, trace_metadata)
else:
node_span = self.build_workflow_task_span(trace_info, node_execution, trace_metadata)
return node_span
except Exception as e:
logger.debug("Error occurred in build_workflow_node_span: %s", e, exc_info=True)
return None
def build_workflow_task_span(
self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata
) -> SpanData:
inputs_json = serialize_json_data(node_execution.inputs)
outputs_json = serialize_json_data(node_execution.outputs)
return SpanData(
trace_id=trace_metadata.trace_id,
parent_span_id=trace_metadata.workflow_span_id,
span_id=convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes=create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.TASK,
inputs=inputs_json,
outputs=outputs_json,
),
status=get_workflow_node_status(node_execution),
links=trace_metadata.links,
)
def build_workflow_tool_span(
self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata
) -> SpanData:
tool_des = {}
if node_execution.metadata:
tool_des = node_execution.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})
inputs_json = serialize_json_data(node_execution.inputs or {})
outputs_json = serialize_json_data(node_execution.outputs)
return SpanData(
trace_id=trace_metadata.trace_id,
parent_span_id=trace_metadata.workflow_span_id,
span_id=convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
**create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.TOOL,
inputs=inputs_json,
outputs=outputs_json,
),
TOOL_NAME: node_execution.title,
TOOL_DESCRIPTION: serialize_json_data(tool_des),
TOOL_PARAMETERS: inputs_json,
},
status=get_workflow_node_status(node_execution),
links=trace_metadata.links,
)
def build_workflow_retrieval_span(
self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata
) -> SpanData:
input_value = str(node_execution.inputs.get("query", "")) if node_execution.inputs else ""
output_value = serialize_json_data(node_execution.outputs.get("result", [])) if node_execution.outputs else ""
retrieval_documents = node_execution.outputs.get("result", []) if node_execution.outputs else []
semantic_retrieval_documents = format_retrieval_documents(retrieval_documents)
semantic_retrieval_documents_json = serialize_json_data(semantic_retrieval_documents)
return SpanData(
trace_id=trace_metadata.trace_id,
parent_span_id=trace_metadata.workflow_span_id,
span_id=convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
**create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.RETRIEVER,
inputs=input_value,
outputs=output_value,
),
RETRIEVAL_QUERY: input_value,
RETRIEVAL_DOCUMENT: semantic_retrieval_documents_json,
},
status=get_workflow_node_status(node_execution),
links=trace_metadata.links,
)
def build_workflow_llm_span(
self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata
) -> SpanData:
process_data = node_execution.process_data or {}
outputs = node_execution.outputs or {}
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
prompts_json = serialize_json_data(process_data.get("prompts", []))
text_output = str(outputs.get("text", ""))
gen_ai_input_message = format_input_messages(process_data)
gen_ai_output_message = format_output_messages(outputs)
return SpanData(
trace_id=trace_metadata.trace_id,
parent_span_id=trace_metadata.workflow_span_id,
span_id=convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
**create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.LLM,
inputs=prompts_json,
outputs=text_output,
),
GEN_AI_REQUEST_MODEL: process_data.get("model_name") or "",
GEN_AI_PROVIDER_NAME: process_data.get("model_provider") or "",
GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
GEN_AI_PROMPT: prompts_json,
GEN_AI_COMPLETION: text_output,
GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason") or "",
GEN_AI_INPUT_MESSAGE: gen_ai_input_message,
GEN_AI_OUTPUT_MESSAGE: gen_ai_output_message,
},
status=get_workflow_node_status(node_execution),
links=trace_metadata.links,
)
def add_workflow_span(self, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata):
message_span_id = None
if trace_info.message_id:
message_span_id = convert_to_span_id(trace_info.message_id, "message")
status = create_status_from_error(trace_info.error)
inputs_json = serialize_json_data(trace_info.workflow_run_inputs)
outputs_json = serialize_json_data(trace_info.workflow_run_outputs)
if message_span_id:
message_span = SpanData(
trace_id=trace_metadata.trace_id,
parent_span_id=None,
span_id=message_span_id,
name="message",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes=create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.CHAIN,
inputs=trace_info.workflow_run_inputs.get("sys.query") or "",
outputs=outputs_json,
),
status=status,
links=trace_metadata.links,
)
self.trace_client.add_span(message_span)
workflow_span = SpanData(
trace_id=trace_metadata.trace_id,
parent_span_id=message_span_id,
span_id=trace_metadata.workflow_span_id,
name="workflow",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes=create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.CHAIN,
inputs=inputs_json,
outputs=outputs_json,
),
status=status,
links=trace_metadata.links,
)
self.trace_client.add_span(workflow_span)
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
message_id = trace_info.message_id
status = create_status_from_error(trace_info.error)
trace_metadata = TraceMetadata(
trace_id=convert_to_trace_id(message_id),
workflow_span_id=0,
session_id=trace_info.metadata.get("conversation_id") or "",
user_id=str(trace_info.metadata.get("user_id") or ""),
links=create_links_from_trace_id(trace_info.trace_id),
)
inputs_json = serialize_json_data(trace_info.inputs)
suggested_question_json = serialize_json_data(trace_info.suggested_question)
suggested_question_span = SpanData(
trace_id=trace_metadata.trace_id,
parent_span_id=convert_to_span_id(message_id, "message"),
span_id=convert_to_span_id(message_id, "suggested_question"),
name="suggested_question",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
**create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.LLM,
inputs=inputs_json,
outputs=suggested_question_json,
),
GEN_AI_REQUEST_MODEL: trace_info.metadata.get("ls_model_name") or "",
GEN_AI_PROVIDER_NAME: trace_info.metadata.get("ls_provider") or "",
GEN_AI_PROMPT: inputs_json,
GEN_AI_COMPLETION: suggested_question_json,
},
status=status,
links=trace_metadata.links,
)
self.trace_client.add_span(suggested_question_span)
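
The span tree assembled above is deterministic: message, workflow, and node span IDs are all derived from the run's UUID, so spans emitted by separate trace tasks join the same trace without coordination. A minimal standalone sketch of that derivation, mirroring convert_to_trace_id and convert_to_span_id from traceclient.py below (the example UUID is made up):

import hashlib
import uuid

def to_trace_id(uuid_v4: str) -> int:
    # A UUID is 128 bits wide, exactly the width of an OTel trace ID.
    return uuid.UUID(uuid_v4).int

def to_span_id(uuid_v4: str, span_type: str) -> int:
    # Hash "<uuid hex>-<span type>", keep the first 8 bytes (64 bits).
    combined = f"{uuid.UUID(uuid_v4).hex}-{span_type}"
    digest = hashlib.sha256(combined.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], byteorder="big", signed=False)

run_id = "123e4567-e89b-42d3-a456-426614174000"  # hypothetical workflow_run_id
assert to_trace_id(run_id) == to_trace_id(run_id)            # stable trace ID
assert to_span_id(run_id, "workflow") != to_span_id(run_id, "message")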

core/ops/aliyun_trace/data_exporter/traceclient.py

@@ -0,0 +1,236 @@
import hashlib
import logging
import random
import socket
import threading
import uuid
from collections import deque
from collections.abc import Sequence
from datetime import datetime
from typing import Final, cast
from urllib.parse import urljoin
import httpx
from opentelemetry import trace as trace_api
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.util.instrumentation import InstrumentationScope
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.trace import Link, SpanContext, TraceFlags
from configs import dify_config
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
INVALID_SPAN_ID: Final[int] = 0x0000000000000000
INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000
DEFAULT_TIMEOUT: Final[int] = 5
DEFAULT_MAX_QUEUE_SIZE: Final[int] = 1000
DEFAULT_SCHEDULE_DELAY_SEC: Final[int] = 5
DEFAULT_MAX_EXPORT_BATCH_SIZE: Final[int] = 50
logger = logging.getLogger(__name__)
class TraceClient:
def __init__(
self,
service_name: str,
endpoint: str,
max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE,
schedule_delay_sec: int = DEFAULT_SCHEDULE_DELAY_SEC,
max_export_batch_size: int = DEFAULT_MAX_EXPORT_BATCH_SIZE,
):
self.endpoint = endpoint
self.resource = Resource(
attributes={
ResourceAttributes.SERVICE_NAME: service_name,
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
ResourceAttributes.HOST_NAME: socket.gethostname(),
}
)
self.span_builder = SpanBuilder(self.resource)
self.exporter = OTLPSpanExporter(endpoint=endpoint)
self.max_queue_size = max_queue_size
self.schedule_delay_sec = schedule_delay_sec
self.max_export_batch_size = max_export_batch_size
self.queue: deque = deque(maxlen=max_queue_size)
self.condition = threading.Condition(threading.Lock())
self.done = False
self.worker_thread = threading.Thread(target=self._worker, daemon=True)
self.worker_thread.start()
self._spans_dropped = False
def export(self, spans: Sequence[ReadableSpan]):
self.exporter.export(spans)
def api_check(self) -> bool:
try:
response = httpx.head(self.endpoint, timeout=DEFAULT_TIMEOUT)
if response.status_code == 405:
return True
else:
logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
return False
except httpx.RequestError as e:
logger.debug("AliyunTrace API check failed: %s", str(e))
raise ValueError(f"AliyunTrace API check failed: {str(e)}")
def get_project_url(self) -> str:
return "https://arms.console.aliyun.com/#/llm"
def add_span(self, span_data: SpanData | None) -> None:
if span_data is None:
return
span: ReadableSpan = self.span_builder.build_span(span_data)
with self.condition:
if len(self.queue) == self.max_queue_size:
if not self._spans_dropped:
logger.warning("Queue is full, likely spans will be dropped.")
self._spans_dropped = True
self.queue.appendleft(span)
if len(self.queue) >= self.max_export_batch_size:
self.condition.notify()
def _worker(self) -> None:
while not self.done:
with self.condition:
if len(self.queue) < self.max_export_batch_size and not self.done:
self.condition.wait(timeout=self.schedule_delay_sec)
self._export_batch()
def _export_batch(self) -> None:
spans_to_export: list[ReadableSpan] = []
with self.condition:
while len(spans_to_export) < self.max_export_batch_size and self.queue:
spans_to_export.append(self.queue.pop())
if spans_to_export:
try:
self.exporter.export(spans_to_export)
except Exception as e:
logger.debug("Error exporting spans: %s", e)
def shutdown(self) -> None:
with self.condition:
self.done = True
self.condition.notify_all()
self.worker_thread.join()
self._export_batch()
self.exporter.shutdown()
class SpanBuilder:
def __init__(self, resource: Resource) -> None:
self.resource = resource
self.instrumentation_scope = InstrumentationScope(
__name__,
"",
None,
None,
)
def build_span(self, span_data: SpanData) -> ReadableSpan:
span_context = trace_api.SpanContext(
trace_id=span_data.trace_id,
span_id=span_data.span_id,
is_remote=False,
trace_flags=trace_api.TraceFlags(trace_api.TraceFlags.SAMPLED),
trace_state=None,
)
parent_span_context = None
if span_data.parent_span_id is not None:
parent_span_context = trace_api.SpanContext(
trace_id=span_data.trace_id,
span_id=span_data.parent_span_id,
is_remote=False,
trace_flags=trace_api.TraceFlags(trace_api.TraceFlags.SAMPLED),
trace_state=None,
)
span = ReadableSpan(
name=span_data.name,
context=span_context,
parent=parent_span_context,
resource=self.resource,
attributes=span_data.attributes,
events=span_data.events,
links=span_data.links,
kind=trace_api.SpanKind.INTERNAL,
status=span_data.status,
start_time=span_data.start_time,
end_time=span_data.end_time,
instrumentation_scope=self.instrumentation_scope,
)
return span
def create_link(trace_id_str: str) -> Link:
placeholder_span_id = INVALID_SPAN_ID
try:
trace_id = int(trace_id_str, 16)
except ValueError as e:
raise ValueError(f"Invalid trace ID format: {trace_id_str}") from e
span_context = SpanContext(
trace_id=trace_id, span_id=placeholder_span_id, is_remote=False, trace_flags=TraceFlags(TraceFlags.SAMPLED)
)
return Link(span_context)
def generate_span_id() -> int:
span_id = random.getrandbits(64)
while span_id == INVALID_SPAN_ID:
span_id = random.getrandbits(64)
return span_id
def convert_to_trace_id(uuid_v4: str | None) -> int:
if uuid_v4 is None:
raise ValueError("UUID cannot be None")
try:
uuid_obj = uuid.UUID(uuid_v4)
return cast(int, uuid_obj.int)
except ValueError as e:
raise ValueError(f"Invalid UUID input: {uuid_v4}") from e
def convert_string_to_id(string: str | None) -> int:
if not string:
return generate_span_id()
hash_bytes = hashlib.sha256(string.encode("utf-8")).digest()
return int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int:
if uuid_v4 is None:
raise ValueError("UUID cannot be None")
try:
uuid_obj = uuid.UUID(uuid_v4)
except ValueError as e:
raise ValueError(f"Invalid UUID input: {uuid_v4}") from e
combined_key = f"{uuid_obj.hex}-{span_type}"
return convert_string_to_id(combined_key)
def convert_datetime_to_nanoseconds(start_time_a: datetime | None) -> int | None:
if start_time_a is None:
return None
timestamp_in_seconds = start_time_a.timestamp()
return int(timestamp_in_seconds * 1e9)
def build_endpoint(base_url: str, license_key: str) -> str:
if "log.aliyuncs.com" in base_url: # cms2.0 endpoint
return urljoin(base_url, f"adapt_{license_key}/api/v1/traces")
else: # xtrace endpoint
return urljoin(base_url, f"adapt_{license_key}/api/otlp/traces")
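
A standalone sketch of how build_endpoint above composes the collector URL with urllib.parse.urljoin; the hostnames and license key are made up. Note that urljoin resolves the relative path against the base, so a base URL that carries a path should end with a trailing slash:

from urllib.parse import urljoin

license_key = "abc123"  # hypothetical
cms_base = "https://proj.cn-hangzhou.log.aliyuncs.com/"       # contains log.aliyuncs.com
xtrace_base = "https://tracing-analysis-dc-hz.aliyuncs.com/"  # any other endpoint

print(urljoin(cms_base, f"adapt_{license_key}/api/v1/traces"))
# https://proj.cn-hangzhou.log.aliyuncs.com/adapt_abc123/api/v1/traces
print(urljoin(xtrace_base, f"adapt_{license_key}/api/otlp/traces"))
# https://tracing-analysis-dc-hz.aliyuncs.com/adapt_abc123/api/otlp/traces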

core/ops/aliyun_trace/entities/aliyun_trace_entity.py

@@ -0,0 +1,36 @@
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
from pydantic import BaseModel, Field
@dataclass
class TraceMetadata:
"""Metadata for trace operations, containing common attributes for all spans in a trace."""
trace_id: int
workflow_span_id: int
session_id: str
user_id: str
links: list[trace_api.Link]
class SpanData(BaseModel):
"""Data model for span information in Aliyun trace system."""
model_config = {"arbitrary_types_allowed": True}
trace_id: int = Field(..., description="The unique identifier for the trace.")
parent_span_id: int | None = Field(None, description="The ID of the parent span, if any.")
span_id: int = Field(..., description="The unique identifier for this span.")
name: str = Field(..., description="The name of the span.")
attributes: dict[str, Any] = Field(default_factory=dict, description="Attributes associated with the span.")
events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.")
links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.")
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
start_time: int | None = Field(..., description="The start time of the span in nanoseconds.")
end_time: int | None = Field(..., description="The end time of the span in nanoseconds.")
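
A hypothetical construction of a SpanData, assuming the entities module above is importable; the arbitrary_types_allowed model config is what lets the non-pydantic OTel types (Status, Event, Link) pass validation:

from opentelemetry.trace import Status, StatusCode

span = SpanData(
    trace_id=0x0123456789ABCDEF0123456789ABCDEF,  # 128-bit trace ID
    parent_span_id=None,                          # root span
    span_id=0x0123456789ABCDEF,                   # 64-bit span ID
    name="message",
    attributes={"gen_ai.span.kind": "CHAIN"},
    status=Status(StatusCode.OK),
    start_time=1_700_000_000_000_000_000,  # nanoseconds since the epoch
    end_time=1_700_000_000_500_000_000,
)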

core/ops/aliyun_trace/entities/semconv.py

@@ -0,0 +1,46 @@
from enum import StrEnum
from typing import Final
# Public attributes
GEN_AI_SESSION_ID: Final[str] = "gen_ai.session.id"
GEN_AI_USER_ID: Final[str] = "gen_ai.user.id"
GEN_AI_USER_NAME: Final[str] = "gen_ai.user.name"
GEN_AI_SPAN_KIND: Final[str] = "gen_ai.span.kind"
GEN_AI_FRAMEWORK: Final[str] = "gen_ai.framework"
# Chain attributes
INPUT_VALUE: Final[str] = "input.value"
OUTPUT_VALUE: Final[str] = "output.value"
# Retriever attributes
RETRIEVAL_QUERY: Final[str] = "retrieval.query"
RETRIEVAL_DOCUMENT: Final[str] = "retrieval.document"
# LLM attributes
GEN_AI_REQUEST_MODEL: Final[str] = "gen_ai.request.model"
GEN_AI_PROVIDER_NAME: Final[str] = "gen_ai.provider.name"
GEN_AI_USAGE_INPUT_TOKENS: Final[str] = "gen_ai.usage.input_tokens"
GEN_AI_USAGE_OUTPUT_TOKENS: Final[str] = "gen_ai.usage.output_tokens"
GEN_AI_USAGE_TOTAL_TOKENS: Final[str] = "gen_ai.usage.total_tokens"
GEN_AI_PROMPT: Final[str] = "gen_ai.prompt"
GEN_AI_COMPLETION: Final[str] = "gen_ai.completion"
GEN_AI_RESPONSE_FINISH_REASON: Final[str] = "gen_ai.response.finish_reason"
GEN_AI_INPUT_MESSAGE: Final[str] = "gen_ai.input.messages"
GEN_AI_OUTPUT_MESSAGE: Final[str] = "gen_ai.output.messages"
# Tool attributes
TOOL_NAME: Final[str] = "tool.name"
TOOL_DESCRIPTION: Final[str] = "tool.description"
TOOL_PARAMETERS: Final[str] = "tool.parameters"
class GenAISpanKind(StrEnum):
CHAIN = "CHAIN"
RETRIEVER = "RETRIEVER"
RERANKER = "RERANKER"
LLM = "LLM"
EMBEDDING = "EMBEDDING"
TOOL = "TOOL"
AGENT = "AGENT"
TASK = "TASK"

core/ops/aliyun_trace/utils.py

@@ -0,0 +1,190 @@
import json
from collections.abc import Mapping
from typing import Any
from opentelemetry.trace import Link, Status, StatusCode
from core.ops.aliyun_trace.entities.semconv import (
GEN_AI_FRAMEWORK,
GEN_AI_SESSION_ID,
GEN_AI_SPAN_KIND,
GEN_AI_USER_ID,
INPUT_VALUE,
OUTPUT_VALUE,
GenAISpanKind,
)
from core.rag.models.document import Document
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import WorkflowNodeExecutionStatus
from extensions.ext_database import db
from models import EndUser
# Constants
DEFAULT_JSON_ENSURE_ASCII = False
DEFAULT_FRAMEWORK_NAME = "dify"
def get_user_id_from_message_data(message_data) -> str:
user_id = message_data.from_account_id
if message_data.from_end_user_id:
end_user_data: EndUser | None = (
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
user_id = end_user_data.session_id
return user_id
def create_status_from_error(error: str | None) -> Status:
if error:
return Status(StatusCode.ERROR, error)
return Status(StatusCode.OK)
def get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status:
if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED:
return Status(StatusCode.OK)
if node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]:
return Status(StatusCode.ERROR, str(node_execution.error))
return Status(StatusCode.UNSET)
def create_links_from_trace_id(trace_id: str | None) -> list[Link]:
from core.ops.aliyun_trace.data_exporter.traceclient import create_link
links = []
if trace_id:
links.append(create_link(trace_id_str=trace_id))
return links
def extract_retrieval_documents(documents: list[Document]) -> list[dict[str, Any]]:
documents_data = []
for document in documents:
document_data = {
"content": document.page_content,
"metadata": {
"dataset_id": document.metadata.get("dataset_id"),
"doc_id": document.metadata.get("doc_id"),
"document_id": document.metadata.get("document_id"),
},
"score": document.metadata.get("score"),
}
documents_data.append(document_data)
return documents_data
def serialize_json_data(data: Any, ensure_ascii: bool = DEFAULT_JSON_ENSURE_ASCII) -> str:
return json.dumps(data, ensure_ascii=ensure_ascii)
def create_common_span_attributes(
session_id: str = "",
user_id: str = "",
span_kind: str = GenAISpanKind.CHAIN,
framework: str = DEFAULT_FRAMEWORK_NAME,
inputs: str = "",
outputs: str = "",
) -> dict[str, Any]:
return {
GEN_AI_SESSION_ID: session_id,
GEN_AI_USER_ID: user_id,
GEN_AI_SPAN_KIND: span_kind,
GEN_AI_FRAMEWORK: framework,
INPUT_VALUE: inputs,
OUTPUT_VALUE: outputs,
}
def format_retrieval_documents(retrieval_documents: list) -> list:
try:
if not isinstance(retrieval_documents, list):
return []
semantic_documents = []
for doc in retrieval_documents:
if not isinstance(doc, dict):
continue
metadata = doc.get("metadata", {})
content = doc.get("content", "")
title = doc.get("title", "")
score = metadata.get("score", 0.0)
document_id = metadata.get("document_id", "")
semantic_metadata = {}
if title:
semantic_metadata["title"] = title
if metadata.get("source"):
semantic_metadata["source"] = metadata["source"]
elif metadata.get("_source"):
semantic_metadata["source"] = metadata["_source"]
if metadata.get("doc_metadata"):
doc_metadata = metadata["doc_metadata"]
if isinstance(doc_metadata, dict):
semantic_metadata.update(doc_metadata)
semantic_doc = {
"document": {"content": content, "metadata": semantic_metadata, "score": score, "id": document_id}
}
semantic_documents.append(semantic_doc)
return semantic_documents
except Exception:
return []
def format_input_messages(process_data: Mapping[str, Any]) -> str:
try:
if not isinstance(process_data, dict):
return serialize_json_data([])
prompts = process_data.get("prompts", [])
if not prompts:
return serialize_json_data([])
valid_roles = {"system", "user", "assistant", "tool"}
input_messages = []
for prompt in prompts:
if not isinstance(prompt, dict):
continue
role = prompt.get("role", "")
text = prompt.get("text", "")
if not role or role not in valid_roles:
continue
if text:
message = {"role": role, "parts": [{"type": "text", "content": text}]}
input_messages.append(message)
return serialize_json_data(input_messages)
except Exception:
return serialize_json_data([])
def format_output_messages(outputs: Mapping[str, Any]) -> str:
try:
if not isinstance(outputs, dict):
return serialize_json_data([])
text = outputs.get("text", "")
finish_reason = outputs.get("finish_reason", "")
if not text:
return serialize_json_data([])
valid_finish_reasons = {"stop", "length", "content_filter", "tool_call", "error"}
if finish_reason not in valid_finish_reasons:
finish_reason = "stop"
output_message = {
"role": "assistant",
"parts": [{"type": "text", "content": text}],
"finish_reason": finish_reason,
}
return serialize_json_data([output_message])
except Exception:
return serialize_json_data([])
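
A standalone illustration of the normalization done by format_input_messages and format_output_messages above, with made-up data; prompts with invalid roles are dropped, and an unrecognized finish_reason falls back to "stop":

process_data = {
    "prompts": [
        {"role": "system", "text": "You are a helpful assistant."},
        {"role": "user", "text": "What is Dify?"},
        {"role": "critic", "text": "dropped: not a valid role"},
    ]
}
print(format_input_messages(process_data))
# [{"role": "system", "parts": [{"type": "text", "content": "You are a helpful assistant."}]},
#  {"role": "user", "parts": [{"type": "text", "content": "What is Dify?"}]}]

outputs = {"text": "Dify is an LLMOps platform.", "finish_reason": "max_tokens"}
print(format_output_messages(outputs))
# [{"role": "assistant", "parts": [{"type": "text", "content": "Dify is an LLMOps platform."}],
#   "finish_reason": "stop"}]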

core/ops/arize_phoenix_trace/arize_phoenix_trace.py

@@ -0,0 +1,719 @@
import json
import logging
import os
import traceback
from datetime import datetime, timedelta
from typing import Any, Union, cast
from urllib.parse import urlparse
from openinference.semconv.trace import OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, SpanAttributes
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from opentelemetry.util.types import AttributeValue
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
TraceTaskName,
WorkflowTraceInfo,
)
from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[trace_sdk.Tracer, SimpleSpanProcessor]:
"""Configure OpenTelemetry tracer with OTLP exporter for Arize/Phoenix."""
try:
# Choose the appropriate exporter based on config type
exporter: Union[GrpcOTLPSpanExporter, HttpOTLPSpanExporter]
# Inspect the provided endpoint to determine its structure
parsed = urlparse(arize_phoenix_config.endpoint)
base_endpoint = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path.rstrip("/")
if isinstance(arize_phoenix_config, ArizeConfig):
arize_endpoint = f"{base_endpoint}/v1"
arize_headers = {
"api_key": arize_phoenix_config.api_key or "",
"space_id": arize_phoenix_config.space_id or "",
"authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
}
exporter = GrpcOTLPSpanExporter(
endpoint=arize_endpoint,
headers=arize_headers,
timeout=30,
)
else:
phoenix_endpoint = f"{base_endpoint}{path}/v1/traces"
phoenix_headers = {
"api_key": arize_phoenix_config.api_key or "",
"authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
}
exporter = HttpOTLPSpanExporter(
endpoint=phoenix_endpoint,
headers=phoenix_headers,
timeout=30,
)
attributes = {
"openinference.project.name": arize_phoenix_config.project or "",
"model_id": arize_phoenix_config.project or "",
}
resource = Resource(attributes=attributes)
provider = trace_sdk.TracerProvider(resource=resource)
processor = SimpleSpanProcessor(
exporter,
)
provider.add_span_processor(processor)
# Create a named tracer instead of setting the global provider
tracer_name = f"arize_phoenix_tracer_{arize_phoenix_config.project}"
logger.info("[Arize/Phoenix] Created tracer with name: %s", tracer_name)
return cast(trace_sdk.Tracer, provider.get_tracer(tracer_name)), processor
except Exception as e:
logger.error("[Arize/Phoenix] Failed to setup the tracer: %s", str(e), exc_info=True)
raise
def datetime_to_nanos(dt: datetime | None) -> int:
"""Convert datetime to nanoseconds since epoch. If None, use current time."""
if dt is None:
dt = datetime.now()
return int(dt.timestamp() * 1_000_000_000)
def error_to_string(error: Exception | str | None) -> str:
"""Convert an error to a string with traceback information."""
error_message = "Empty Stack Trace"
if error:
if isinstance(error, Exception):
string_stacktrace = "".join(traceback.format_exception(error))
error_message = f"{error.__class__.__name__}: {error}\n\n{string_stacktrace}"
else:
error_message = str(error)
return error_message
def set_span_status(current_span: Span, error: Exception | str | None = None):
"""Set the status of the current span based on the presence of an error."""
if error:
error_string = error_to_string(error)
current_span.set_status(Status(StatusCode.ERROR, error_string))
if isinstance(error, Exception):
current_span.record_exception(error)
else:
exception_type = error.__class__.__name__
exception_message = str(error)
if not exception_message:
exception_message = repr(error)
attributes: dict[str, AttributeValue] = {
OTELSpanAttributes.EXCEPTION_TYPE: exception_type,
OTELSpanAttributes.EXCEPTION_MESSAGE: exception_message,
OTELSpanAttributes.EXCEPTION_ESCAPED: False,
OTELSpanAttributes.EXCEPTION_STACKTRACE: error_string,
}
current_span.add_event(name="exception", attributes=attributes)
else:
current_span.set_status(Status(StatusCode.OK))
def safe_json_dumps(obj: Any) -> str:
"""A convenience wrapper around `json.dumps` that ensures that any object can be safely encoded."""
return json.dumps(obj, default=str, ensure_ascii=False)
class ArizePhoenixDataTrace(BaseTraceInstance):
def __init__(
self,
arize_phoenix_config: ArizeConfig | PhoenixConfig,
):
super().__init__(arize_phoenix_config)
import logging
logging.basicConfig()
logging.getLogger().setLevel(logging.DEBUG)
self.arize_phoenix_config = arize_phoenix_config
self.tracer, self.processor = setup_tracer(arize_phoenix_config)
self.project = arize_phoenix_config.project
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
self.propagator = TraceContextTextMapPropagator()
self.dify_trace_ids: set[str] = set()
def trace(self, trace_info: BaseTraceInfo):
logger.info("[Arize/Phoenix] Trace Entity Info: %s", trace_info)
logger.info("[Arize/Phoenix] Trace Entity Type: %s", type(trace_info))
try:
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo):
self.moderation_trace(trace_info)
if isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo):
self.generate_name_trace(trace_info)
except Exception as e:
logger.error("[Arize/Phoenix] Trace Entity Error: %s", str(e), exc_info=True)
raise
def workflow_trace(self, trace_info: WorkflowTraceInfo):
workflow_metadata = {
"workflow_run_id": trace_info.workflow_run_id or "",
"message_id": trace_info.message_id or "",
"workflow_app_log_id": trace_info.workflow_app_log_id or "",
"status": trace_info.workflow_run_status or "",
"status_message": trace_info.error or "",
"level": "ERROR" if trace_info.error else "DEFAULT",
"total_tokens": trace_info.total_tokens or 0,
}
workflow_metadata.update(trace_info.metadata)
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
workflow_span = self.tracer.start_span(
name=TraceTaskName.WORKFLOW_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(workflow_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
start_time=datetime_to_nanos(trace_info.start_time),
context=root_span_context,
)
        # Fetch all node executions for this workflow run via the repository
session_factory = sessionmaker(bind=db.engine)
# Find the app's creator account
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
service_account = self.get_service_account_with_tenant(app_id)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Get all executions for this workflow run
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
workflow_run_id=trace_info.workflow_run_id
)
try:
for node_execution in workflow_node_executions:
tenant_id = trace_info.tenant_id # Use from trace_info instead
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
inputs_value = node_execution.inputs or {}
outputs_value = node_execution.outputs or {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
process_data = node_execution.process_data or {}
execution_metadata = node_execution.metadata or {}
node_metadata = {str(k): v for k, v in execution_metadata.items()}
node_metadata.update(
{
"node_id": node_execution.id,
"node_type": node_execution.node_type,
"node_status": node_execution.status,
"tenant_id": tenant_id,
"app_id": app_id,
"app_name": node_execution.title,
"status": node_execution.status,
"level": "ERROR" if node_execution.status == "failed" else "DEFAULT",
}
)
# Determine the correct span kind based on node type
span_kind = OpenInferenceSpanKindValues.CHAIN
if node_execution.node_type == "llm":
span_kind = OpenInferenceSpanKindValues.LLM
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider:
node_metadata["ls_provider"] = provider
if model:
node_metadata["ls_model_name"] = model
usage_data = (
process_data.get("usage", {}) if "usage" in process_data else outputs_value.get("usage", {})
)
if usage_data:
node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)
elif node_execution.node_type == "dataset_retrieval":
span_kind = OpenInferenceSpanKindValues.RETRIEVER
elif node_execution.node_type == "tool":
span_kind = OpenInferenceSpanKindValues.TOOL
else:
span_kind = OpenInferenceSpanKindValues.CHAIN
workflow_span_context = set_span_in_context(workflow_span)
node_span = self.tracer.start_span(
name=node_execution.node_type,
attributes={
SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
SpanAttributes.METADATA: safe_json_dumps(node_metadata),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
start_time=datetime_to_nanos(created_at),
context=workflow_span_context,
)
try:
if node_execution.node_type == "llm":
llm_attributes: dict[str, Any] = {
SpanAttributes.INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
}
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider:
llm_attributes[SpanAttributes.LLM_PROVIDER] = provider
if model:
llm_attributes[SpanAttributes.LLM_MODEL_NAME] = model
usage_data = (
process_data.get("usage", {}) if "usage" in process_data else outputs_value.get("usage", {})
)
if usage_data:
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = usage_data.get("total_tokens", 0)
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = usage_data.get("prompt_tokens", 0)
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = usage_data.get(
"completion_tokens", 0
)
llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
node_span.set_attributes(llm_attributes)
finally:
if node_execution.status == "failed":
set_span_status(node_span, node_execution.error)
else:
set_span_status(node_span)
node_span.end(end_time=datetime_to_nanos(finished_at))
finally:
if trace_info.error:
set_span_status(workflow_span, trace_info.error)
else:
set_span_status(workflow_span)
workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))
def message_trace(self, trace_info: MessageTraceInfo):
if trace_info.message_data is None:
return
file_list = cast(list[str], trace_info.file_list) or []
message_file_data: MessageFile | None = trace_info.message_file_data
if message_file_data is not None:
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
file_list.append(file_url)
message_metadata = {
"message_id": trace_info.message_id or "",
"conversation_mode": str(trace_info.conversation_mode or ""),
"user_id": trace_info.message_data.from_account_id or "",
"file_list": json.dumps(file_list),
"status": trace_info.message_data.status or "",
"status_message": trace_info.error or "",
"level": "ERROR" if trace_info.error else "DEFAULT",
"total_tokens": trace_info.total_tokens or 0,
"prompt_tokens": trace_info.message_tokens or 0,
"completion_tokens": trace_info.answer_tokens or 0,
"ls_provider": trace_info.message_data.model_provider or "",
"ls_model_name": trace_info.message_data.model_id or "",
}
message_metadata.update(trace_info.metadata)
# Add end user data if available
if trace_info.message_data.from_end_user_id:
end_user_data: EndUser | None = (
db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first()
)
if end_user_data is not None:
message_metadata["end_user_id"] = end_user_data.session_id
attributes = {
SpanAttributes.INPUT_VALUE: trace_info.message_data.query,
SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer,
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
}
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
message_span = self.tracer.start_span(
name=TraceTaskName.MESSAGE_TRACE.value,
attributes=attributes,
start_time=datetime_to_nanos(trace_info.start_time),
context=root_span_context,
)
try:
# Convert outputs to string based on type
if isinstance(trace_info.outputs, dict | list):
outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False)
elif isinstance(trace_info.outputs, str):
outputs_str = trace_info.outputs
else:
outputs_str = str(trace_info.outputs)
llm_attributes = {
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value,
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: outputs_str,
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
}
llm_attributes.update(self._construct_llm_attributes(trace_info.inputs))
if trace_info.total_tokens is not None and trace_info.total_tokens > 0:
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = trace_info.total_tokens
if trace_info.message_tokens is not None and trace_info.message_tokens > 0:
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = trace_info.message_tokens
if trace_info.answer_tokens is not None and trace_info.answer_tokens > 0:
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = trace_info.answer_tokens
if trace_info.message_data.model_id is not None:
llm_attributes[SpanAttributes.LLM_MODEL_NAME] = trace_info.message_data.model_id
if trace_info.message_data.model_provider is not None:
llm_attributes[SpanAttributes.LLM_PROVIDER] = trace_info.message_data.model_provider
if trace_info.message_data and trace_info.message_data.message_metadata:
metadata_dict = json.loads(trace_info.message_data.message_metadata)
if model_params := metadata_dict.get("model_parameters"):
llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)
message_span_context = set_span_in_context(message_span)
llm_span = self.tracer.start_span(
name="llm",
attributes=llm_attributes,
start_time=datetime_to_nanos(trace_info.start_time),
context=message_span_context,
)
try:
if trace_info.message_data.error:
set_span_status(llm_span, trace_info.message_data.error)
else:
set_span_status(llm_span)
finally:
llm_span.end(end_time=datetime_to_nanos(trace_info.end_time))
finally:
if trace_info.error:
set_span_status(message_span, trace_info.error)
else:
set_span_status(message_span)
message_span.end(end_time=datetime_to_nanos(trace_info.end_time))
def moderation_trace(self, trace_info: ModerationTraceInfo):
if trace_info.message_data is None:
return
metadata = {
"message_id": trace_info.message_id,
"tool_name": "moderation",
"status": trace_info.message_data.status,
"status_message": trace_info.message_data.error or "",
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
}
metadata.update(trace_info.metadata)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
span = self.tracer.start_span(
name=TraceTaskName.MODERATION_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(
{
"action": trace_info.action,
"flagged": trace_info.flagged,
"preset_response": trace_info.preset_response,
"inputs": trace_info.inputs,
},
ensure_ascii=False,
),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
},
start_time=datetime_to_nanos(trace_info.start_time),
context=root_span_context,
)
try:
if trace_info.message_data.error:
set_span_status(span, trace_info.message_data.error)
else:
set_span_status(span)
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
if trace_info.message_data is None:
return
start_time = trace_info.start_time or trace_info.message_data.created_at
end_time = trace_info.end_time or trace_info.message_data.updated_at
metadata = {
"message_id": trace_info.message_id,
"tool_name": "suggested_question",
"status": trace_info.status,
"status_message": trace_info.error or "",
"level": "ERROR" if trace_info.error else "DEFAULT",
"total_tokens": trace_info.total_tokens,
"ls_provider": trace_info.model_provider or "",
"ls_model_name": trace_info.model_id or "",
}
metadata.update(trace_info.metadata)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
span = self.tracer.start_span(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
},
start_time=datetime_to_nanos(start_time),
context=root_span_context,
)
try:
if trace_info.error:
set_span_status(span, trace_info.error)
else:
set_span_status(span)
finally:
span.end(end_time=datetime_to_nanos(end_time))
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
return
start_time = trace_info.start_time or trace_info.message_data.created_at
end_time = trace_info.end_time or trace_info.message_data.updated_at
metadata = {
"message_id": trace_info.message_id,
"tool_name": "dataset_retrieval",
"status": trace_info.message_data.status,
"status_message": trace_info.message_data.error or "",
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
"ls_provider": trace_info.message_data.model_provider or "",
"ls_model_name": trace_info.message_data.model_id or "",
}
metadata.update(trace_info.metadata)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
span = self.tracer.start_span(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps({"documents": trace_info.documents}, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.RETRIEVER.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
"start_time": start_time.isoformat() if start_time else "",
"end_time": end_time.isoformat() if end_time else "",
},
start_time=datetime_to_nanos(start_time),
context=root_span_context,
)
try:
if trace_info.message_data.error:
set_span_status(span, trace_info.message_data.error)
else:
set_span_status(span)
finally:
span.end(end_time=datetime_to_nanos(end_time))
def tool_trace(self, trace_info: ToolTraceInfo):
if trace_info.message_data is None:
logger.warning("[Arize/Phoenix] Message data is None, skipping tool trace.")
return
metadata = {
"message_id": trace_info.message_id,
"tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
}
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
tool_params_str = (
json.dumps(trace_info.tool_parameters, ensure_ascii=False)
if isinstance(trace_info.tool_parameters, dict)
else str(trace_info.tool_parameters)
)
span = self.tracer.start_span(
name=trace_info.tool_name,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs,
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
SpanAttributes.TOOL_NAME: trace_info.tool_name,
SpanAttributes.TOOL_PARAMETERS: tool_params_str,
},
start_time=datetime_to_nanos(trace_info.start_time),
context=root_span_context,
)
try:
if trace_info.error:
set_span_status(span, trace_info.error)
else:
set_span_status(span)
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
if trace_info.message_data is None:
return
metadata = {
"project_name": self.project,
"message_id": trace_info.message_id,
"status": trace_info.message_data.status,
"status_message": trace_info.message_data.error or "",
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
}
metadata.update(trace_info.metadata)
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.conversation_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
span = self.tracer.start_span(
name=TraceTaskName.GENERATE_NAME_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.outputs, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
"start_time": trace_info.start_time.isoformat() if trace_info.start_time else "",
"end_time": trace_info.end_time.isoformat() if trace_info.end_time else "",
},
start_time=datetime_to_nanos(trace_info.start_time),
context=root_span_context,
)
try:
if trace_info.message_data.error:
set_span_status(span, trace_info.message_data.error)
else:
set_span_status(span)
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
def ensure_root_span(self, dify_trace_id: str | None):
"""Ensure a unique root span exists for the given Dify trace ID."""
if str(dify_trace_id) not in self.dify_trace_ids:
self.carrier: dict[str, str] = {}
root_span = self.tracer.start_span(name="Dify")
root_span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, OpenInferenceSpanKindValues.CHAIN.value)
root_span.set_attribute("dify_project_name", str(self.project))
root_span.set_attribute("dify_trace_id", str(dify_trace_id))
with use_span(root_span, end_on_exit=False):
self.propagator.inject(carrier=self.carrier)
set_span_status(root_span)
root_span.end()
self.dify_trace_ids.add(str(dify_trace_id))
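# Note: injecting the root span's context into self.carrier is what lets the
# *_trace methods above attach their spans to the shared "Dify" root span via
# self.propagator.extract(carrier=self.carrier).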
def api_check(self):
try:
with self.tracer.start_span("api_check") as span:
span.set_attribute("test", "true")
return True
except Exception as e:
logger.info("[Arize/Phoenix] API check failed: %s", str(e), exc_info=True)
raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}")
def get_project_url(self):
try:
if self.arize_phoenix_config.endpoint == "https://otlp.arize.com":
return "https://app.arize.com/"
else:
return f"{self.arize_phoenix_config.endpoint}/projects/"
except Exception as e:
logger.info("[Arize/Phoenix] Get run url failed: %s", str(e), exc_info=True)
raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
"""Helper method to construct LLM attributes with passed prompts."""
attributes = {}
if isinstance(prompts, list):
for i, msg in enumerate(prompts):
if isinstance(msg, dict):
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get("role", "user")
# todo: handle assistant and tool role messages, as they don't always
# have a text field, but may have a tool_calls field instead
# e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
# 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
elif isinstance(prompts, dict):
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(prompts)
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
elif isinstance(prompts, str):
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = prompts
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
return attributes
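# Example (sketch) of the flattened attributes this helper produces; the
# message text below is illustrative:
#
#   _construct_llm_attributes([{"role": "system", "text": "be brief"}])
#   # -> {f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content": "be brief",
#   #     f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role": "system"}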

View File

@@ -0,0 +1,67 @@
from abc import ABC, abstractmethod
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.entities.trace_entity import BaseTraceInfo
from extensions.ext_database import db
from models import Account, App, TenantAccountJoin
class BaseTraceInstance(ABC):
"""
Base trace instance for ops trace services
"""
@abstractmethod
def __init__(self, trace_config: BaseTracingConfig):
"""
Abstract initializer for the trace instance.
Subclasses set up their tracing client from the given config.
"""
self.trace_config = trace_config
@abstractmethod
def trace(self, trace_info: BaseTraceInfo):
"""
Abstract method to trace activities.
Subclasses must implement specific tracing logic for activities.
"""
...
def get_service_account_with_tenant(self, app_id: str) -> Account:
"""
Get service account for an app and set up its tenant.
Args:
app_id: The ID of the app
Returns:
Account: The service account with tenant set up
Raises:
ValueError: If app, creator account or tenant cannot be found
"""
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
app_stmt = select(App).where(App.id == app_id)
app = session.scalar(app_stmt)
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
account_stmt = select(Account).where(Account.id == app.created_by)
service_account = session.scalar(account_stmt)
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = (
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
)
if not current_tenant:
raise ValueError(f"Current tenant not found for account {service_account.id}")
service_account.set_tenant_id(current_tenant.tenant_id)
return service_account
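# Example (sketch): a minimal concrete tracer. The class name and print target
# are illustrative assumptions, not part of this module:
#
#   class StdoutDataTrace(BaseTraceInstance):
#       def __init__(self, trace_config: BaseTracingConfig):
#           super().__init__(trace_config)
#
#       def trace(self, trace_info: BaseTraceInfo):
#           print(trace_info.model_dump_json())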

View File

View File

@@ -0,0 +1,271 @@
from enum import StrEnum
from pydantic import BaseModel, ValidationInfo, field_validator
from core.ops.utils import validate_integer_id, validate_project_name, validate_url, validate_url_with_path
class TracingProviderEnum(StrEnum):
ARIZE = "arize"
PHOENIX = "phoenix"
LANGFUSE = "langfuse"
LANGSMITH = "langsmith"
OPIK = "opik"
WEAVE = "weave"
ALIYUN = "aliyun"
MLFLOW = "mlflow"
DATABRICKS = "databricks"
TENCENT = "tencent"
class BaseTracingConfig(BaseModel):
"""
Base model class for tracing configurations
"""
@classmethod
def validate_endpoint_url(cls, v: str, default_url: str) -> str:
"""
Common endpoint URL validation logic
Args:
v: URL value to validate
default_url: Default URL to use if input is None or empty
Returns:
Validated and normalized URL
"""
return validate_url(v, default_url)
@classmethod
def validate_project_field(cls, v: str, default_name: str) -> str:
"""
Common project name validation logic
Args:
v: Project name to validate
default_name: Default name to use if input is None or empty
Returns:
Validated project name
"""
return validate_project_name(v, default_name)
class ArizeConfig(BaseTracingConfig):
"""
Model class for Arize tracing config.
"""
api_key: str | None = None
space_id: str | None = None
project: str | None = None
endpoint: str = "https://otlp.arize.com"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "default")
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://otlp.arize.com")
class PhoenixConfig(BaseTracingConfig):
"""
Model class for Phoenix tracing config.
"""
api_key: str | None = None
project: str | None = None
endpoint: str = "https://app.phoenix.arize.com"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "default")
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return validate_url_with_path(v, "https://app.phoenix.arize.com")
class LangfuseConfig(BaseTracingConfig):
"""
Model class for Langfuse tracing config.
"""
public_key: str
secret_key: str
host: str = "https://api.langfuse.com"
@field_validator("host")
@classmethod
def host_validator(cls, v, info: ValidationInfo):
return validate_url_with_path(v, "https://api.langfuse.com")
class LangSmithConfig(BaseTracingConfig):
"""
Model class for Langsmith tracing config.
"""
api_key: str
project: str
endpoint: str = "https://api.smith.langchain.com"
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
# LangSmith only allows HTTPS
return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",))
class OpikConfig(BaseTracingConfig):
"""
Model class for Opik tracing config.
"""
api_key: str | None = None
project: str | None = None
workspace: str | None = None
url: str = "https://www.comet.com/opik/api/"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "Default Project")
@field_validator("url")
@classmethod
def url_validator(cls, v, info: ValidationInfo):
return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/")
class WeaveConfig(BaseTracingConfig):
"""
Model class for Weave tracing config.
"""
api_key: str
entity: str | None = None
project: str
endpoint: str = "https://trace.wandb.ai"
host: str | None = None
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
# Weave only allows HTTPS for endpoint
return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))
@field_validator("host")
@classmethod
def host_validator(cls, v, info: ValidationInfo):
if v is not None and v.strip() != "":
return validate_url(v, v, allowed_schemes=("https", "http"))
return v
class AliyunConfig(BaseTracingConfig):
"""
Model class for Aliyun tracing config.
"""
app_name: str = "dify_app"
license_key: str
endpoint: str
@field_validator("app_name")
@classmethod
def app_name_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "dify_app")
@field_validator("license_key")
@classmethod
def license_key_validator(cls, v, info: ValidationInfo):
if not v or v.strip() == "":
raise ValueError("License key cannot be empty")
return v
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
# Aliyun endpoints come in two URL formats and may include a URL path
return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
class TencentConfig(BaseTracingConfig):
"""
Tencent APM tracing config
"""
token: str
endpoint: str
service_name: str
@field_validator("token")
@classmethod
def token_validator(cls, v, info: ValidationInfo):
if not v or v.strip() == "":
raise ValueError("Token cannot be empty")
return v
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com")
@field_validator("service_name")
@classmethod
def service_name_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "dify_app")
class MLflowConfig(BaseTracingConfig):
"""
Model class for MLflow tracing config.
"""
tracking_uri: str = "http://localhost:5000"
experiment_id: str = "0" # Default experiment id in MLflow is 0
username: str | None = None
password: str | None = None
@field_validator("tracking_uri")
@classmethod
def tracking_uri_validator(cls, v, info: ValidationInfo):
if isinstance(v, str) and v.startswith("databricks"):
raise ValueError(
"Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances."
)
return validate_url_with_path(v, "http://localhost:5000")
@field_validator("experiment_id")
@classmethod
def experiment_id_validator(cls, v, info: ValidationInfo):
return validate_integer_id(v)
class DatabricksConfig(BaseTracingConfig):
"""
Model class for Databricks (Databricks-managed MLflow) tracing config.
"""
experiment_id: str
host: str
client_id: str | None = None
client_secret: str | None = None
personal_access_token: str | None = None
@field_validator("experiment_id")
@classmethod
def experiment_id_validator(cls, v, info: ValidationInfo):
return validate_integer_id(v)
OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
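# Example (sketch): how the validators normalize or reject endpoint values.
# The concrete outcomes below are illustrative, assuming the validate_url*
# helpers fall back to the documented defaults for empty input:
#
#   ArizeConfig(endpoint="").endpoint  # -> "https://otlp.arize.com"
#   LangSmithConfig(api_key="k", project="p", endpoint="http://x")  # rejected: HTTPS only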

View File

@@ -0,0 +1,143 @@
from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Any, Union
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
class BaseTraceInfo(BaseModel):
message_id: str | None = None
message_data: Any | None = None
inputs: Union[str, dict[str, Any], list] | None = None
outputs: Union[str, dict[str, Any], list] | None = None
start_time: datetime | None = None
end_time: datetime | None = None
metadata: dict[str, Any]
trace_id: str | None = None
@field_validator("inputs", "outputs")
@classmethod
def ensure_type(cls, v):
if v is None:
return None
if isinstance(v, str | dict | list):
return v
return ""
model_config = ConfigDict(protected_namespaces=())
@field_serializer("start_time", "end_time")
def serialize_datetime(self, dt: datetime | None) -> str | None:
if dt is None:
return None
return dt.isoformat()
class WorkflowTraceInfo(BaseTraceInfo):
workflow_data: Any = None
conversation_id: str | None = None
workflow_app_log_id: str | None = None
workflow_id: str
tenant_id: str
workflow_run_id: str
workflow_run_elapsed_time: Union[int, float]
workflow_run_status: str
workflow_run_inputs: Mapping[str, Any]
workflow_run_outputs: Mapping[str, Any]
workflow_run_version: str
error: str | None = None
total_tokens: int
file_list: list[str]
query: str
metadata: dict[str, Any]
class MessageTraceInfo(BaseTraceInfo):
conversation_model: str
message_tokens: int
answer_tokens: int
total_tokens: int
error: str | None = None
file_list: Union[str, dict[str, Any], list] | None = None
message_file_data: Any | None = None
conversation_mode: str
gen_ai_server_time_to_first_token: float | None = None
llm_streaming_time_to_generate: float | None = None
is_streaming_request: bool = False
class ModerationTraceInfo(BaseTraceInfo):
flagged: bool
action: str
preset_response: str
query: str
class SuggestedQuestionTraceInfo(BaseTraceInfo):
total_tokens: int
status: str | None = None
error: str | None = None
from_account_id: str | None = None
agent_based: bool | None = None
from_source: str | None = None
model_provider: str | None = None
model_id: str | None = None
suggested_question: list[str]
level: str
status_message: str | None = None
workflow_run_id: str | None = None
model_config = ConfigDict(protected_namespaces=())
class DatasetRetrievalTraceInfo(BaseTraceInfo):
documents: Any = None
error: str | None = None
class ToolTraceInfo(BaseTraceInfo):
tool_name: str
tool_inputs: dict[str, Any]
tool_outputs: str
metadata: dict[str, Any]
message_file_data: Any = None
error: str | None = None
tool_config: dict[str, Any]
time_cost: Union[int, float]
tool_parameters: dict[str, Any]
file_url: Union[str, None, list] = None
class GenerateNameTraceInfo(BaseTraceInfo):
conversation_id: str | None = None
tenant_id: str
class TaskData(BaseModel):
app_id: str
trace_info_type: str
trace_info: Any = None
trace_info_info_map = {
"WorkflowTraceInfo": WorkflowTraceInfo,
"MessageTraceInfo": MessageTraceInfo,
"ModerationTraceInfo": ModerationTraceInfo,
"SuggestedQuestionTraceInfo": SuggestedQuestionTraceInfo,
"DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo,
"ToolTraceInfo": ToolTraceInfo,
"GenerateNameTraceInfo": GenerateNameTraceInfo,
}
class TraceTaskName(StrEnum):
CONVERSATION_TRACE = "conversation"
WORKFLOW_TRACE = "workflow"
MESSAGE_TRACE = "message"
MODERATION_TRACE = "moderation"
SUGGESTED_QUESTION_TRACE = "suggested_question"
DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
TOOL_TRACE = "tool"
GENERATE_NAME_TRACE = "generate_conversation_name"
DATASOURCE_TRACE = "datasource"
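# Example (sketch): rehydrating a queued trace task via trace_info_info_map.
# The payload shape is illustrative:
#
#   task = TaskData(app_id="app-1", trace_info_type="ToolTraceInfo", trace_info={...})
#   info_cls = trace_info_info_map[task.trace_info_type]
#   trace_info = info_cls(**task.trace_info)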

View File

@@ -0,0 +1,283 @@
from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Any, Union
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.ops.utils import replace_text_with_content
def validate_input_output(v, field_name):
"""
Validate input output
:param v:
:param field_name:
:return:
"""
if v == {} or v is None:
return v
if isinstance(v, str):
return [
{
"role": "assistant" if field_name == "output" else "user",
"content": v,
}
]
elif isinstance(v, list):
if len(v) > 0 and isinstance(v[0], dict):
v = replace_text_with_content(data=v)
return v
else:
return [
{
"role": "assistant" if field_name == "output" else "user",
"content": str(v),
}
]
return v
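# Examples of the normalization above (the role depends on field_name):
#
#   validate_input_output("hi", "output") -> [{"role": "assistant", "content": "hi"}]
#   validate_input_output("hi", "input")  -> [{"role": "user", "content": "hi"}]
#   validate_input_output([{"text": "x"}], "input")
#       -> result of replace_text_with_content, i.e. "text" keys renamed to "content"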
class LevelEnum(StrEnum):
DEBUG = "DEBUG"
WARNING = "WARNING"
ERROR = "ERROR"
DEFAULT = "DEFAULT"
class LangfuseTrace(BaseModel):
"""
Langfuse trace model
"""
id: str | None = Field(
default=None,
description="The id of the trace can be set, defaults to a random id. Used to link traces to external systems "
"or when creating a distributed trace. Traces are upserted on id.",
)
name: str | None = Field(
default=None,
description="Identifier of the trace. Useful for sorting/filtering in the UI.",
)
input: Union[str, dict[str, Any], list, None] | None = Field(
default=None, description="The input of the trace. Can be any JSON object."
)
output: Union[str, dict[str, Any], list, None] | None = Field(
default=None, description="The output of the trace. Can be any JSON object."
)
metadata: dict[str, Any] | None = Field(
default=None,
description="Additional metadata of the trace. Can be any JSON object. Metadata is merged when being updated "
"via the API.",
)
user_id: str | None = Field(
default=None,
description="The id of the user that triggered the execution. Used to provide user-level analytics.",
)
session_id: str | None = Field(
default=None,
description="Used to group multiple traces into a session in Langfuse. Use your own session/thread identifier.",
)
version: str | None = Field(
default=None,
description="The version of the trace type. Used to understand how changes to the trace type affect metrics. "
"Useful in debugging.",
)
release: str | None = Field(
default=None,
description="The release identifier of the current deployment. Used to understand how changes of different "
"deployments affect metrics. Useful in debugging.",
)
tags: list[str] | None = Field(
default=None,
description="Tags are used to categorize or label traces. Traces can be filtered by tags in the UI and GET "
"API. Tags can also be changed in the UI. Tags are merged and never deleted via the API.",
)
public: bool | None = Field(
default=None,
description="You can make a trace public to share it via a public link. This allows others to view the trace "
"without needing to log in or be members of your Langfuse project.",
)
@field_validator("input", "output")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name
return validate_input_output(v, field_name)
class LangfuseSpan(BaseModel):
"""
Langfuse span model
"""
id: str | None = Field(
default=None,
description="The id of the span can be set, otherwise a random id is generated. Spans are upserted on id.",
)
session_id: str | None = Field(
default=None,
description="Used to group multiple spans into a session in Langfuse. Use your own session/thread identifier.",
)
trace_id: str | None = Field(
default=None,
description="The id of the trace the span belongs to. Used to link spans to traces.",
)
user_id: str | None = Field(
default=None,
description="The id of the user that triggered the execution. Used to provide user-level analytics.",
)
start_time: datetime | str | None = Field(
default_factory=datetime.now,
description="The time at which the span started, defaults to the current time.",
)
end_time: datetime | str | None = Field(
default=None,
description="The time at which the span ended. Automatically set by span.end().",
)
name: str | None = Field(
default=None,
description="Identifier of the span. Useful for sorting/filtering in the UI.",
)
metadata: dict[str, Any] | None = Field(
default=None,
description="Additional metadata of the span. Can be any JSON object. Metadata is merged when being updated "
"via the API.",
)
level: str | None = Field(
default=None,
description="The level of the span. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering of "
"traces with elevated error levels and for highlighting in the UI.",
)
status_message: str | None = Field(
default=None,
description="The status message of the span. Additional field for context of the event. E.g. the error "
"message of an error event.",
)
input: Union[str, Mapping[str, Any], list, None] | None = Field(
default=None, description="The input of the span. Can be any JSON object."
)
output: Union[str, Mapping[str, Any], list, None] | None = Field(
default=None, description="The output of the span. Can be any JSON object."
)
version: str | None = Field(
default=None,
description="The version of the span type. Used to understand how changes to the span type affect metrics. "
"Useful in debugging.",
)
parent_observation_id: str | None = Field(
default=None,
description="The id of the observation the span belongs to. Used to link spans to observations.",
)
@field_validator("input", "output")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name
return validate_input_output(v, field_name)
class UnitEnum(StrEnum):
CHARACTERS = "CHARACTERS"
TOKENS = "TOKENS"
SECONDS = "SECONDS"
MILLISECONDS = "MILLISECONDS"
IMAGES = "IMAGES"
class GenerationUsage(BaseModel):
promptTokens: int | None = None
completionTokens: int | None = None
total: int | None = None
input: int | None = None
output: int | None = None
unit: UnitEnum | None = None
inputCost: float | None = None
outputCost: float | None = None
totalCost: float | None = None
@field_validator("input", "output")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name
return validate_input_output(v, field_name)
class LangfuseGeneration(BaseModel):
id: str | None = Field(
default=None,
description="The id of the generation can be set, defaults to random id.",
)
trace_id: str | None = Field(
default=None,
description="The id of the trace the generation belongs to. Used to link generations to traces.",
)
parent_observation_id: str | None = Field(
default=None,
description="The id of the observation the generation belongs to. Used to link generations to observations.",
)
name: str | None = Field(
default=None,
description="Identifier of the generation. Useful for sorting/filtering in the UI.",
)
start_time: datetime | str | None = Field(
default_factory=datetime.now,
description="The time at which the generation started, defaults to the current time.",
)
completion_start_time: datetime | str | None = Field(
default=None,
description="The time at which the completion started (streaming). Set it to get latency analytics broken "
"down into time until completion started and completion duration.",
)
end_time: datetime | str | None = Field(
default=None,
description="The time at which the generation ended. Automatically set by generation.end().",
)
model: str | None = Field(default=None, description="The name of the model used for the generation.")
model_parameters: dict[str, Any] | None = Field(
default=None,
description="The parameters of the model used for the generation; can be any key-value pairs.",
)
input: Any | None = Field(
default=None,
description="The prompt used for the generation. Can be any string or JSON object.",
)
output: Any | None = Field(
default=None,
description="The completion generated by the model. Can be any string or JSON object.",
)
usage: GenerationUsage | None = Field(
default=None,
description="The usage object supports the OpenAi structure with tokens and a more generic version with "
"detailed costs and units.",
)
metadata: dict[str, Any] | None = Field(
default=None,
description="Additional metadata of the generation. Can be any JSON object. Metadata is merged when being "
"updated via the API.",
)
level: LevelEnum | None = Field(
default=None,
description="The level of the generation. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering "
"of traces with elevated error levels and for highlighting in the UI.",
)
status_message: str | None = Field(
default=None,
description="The status message of the generation. Additional field for context of the event. E.g. the error "
"message of an error event.",
)
version: str | None = Field(
default=None,
description="The version of the generation type. Used to understand how changes to the span type affect "
"metrics. Useful in debugging.",
)
model_config = ConfigDict(protected_namespaces=())
@field_validator("input", "output")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name
return validate_input_output(v, field_name)

View File

@@ -0,0 +1,452 @@
import logging
import os
from datetime import datetime, timedelta
from langfuse import Langfuse
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangfuseConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
GenerationUsage,
LangfuseGeneration,
LangfuseSpan,
LangfuseTrace,
LevelEnum,
UnitEnum,
)
from core.ops.utils import filter_none_values
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.enums import NodeType
from extensions.ext_database import db
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
from models.enums import MessageStatus
logger = logging.getLogger(__name__)
class LangFuseDataTrace(BaseTraceInstance):
def __init__(
self,
langfuse_config: LangfuseConfig,
):
super().__init__(langfuse_config)
self.langfuse_client = Langfuse(
public_key=langfuse_config.public_key,
secret_key=langfuse_config.secret_key,
host=langfuse_config.host,
)
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
def trace(self, trace_info: BaseTraceInfo):
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo):
self.moderation_trace(trace_info)
if isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo):
self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.trace_id or trace_info.workflow_run_id
user_id = trace_info.metadata.get("user_id")
metadata = trace_info.metadata
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
if trace_info.message_id:
trace_id = trace_info.trace_id or trace_info.message_id
name = TraceTaskName.MESSAGE_TRACE
trace_data = LangfuseTrace(
id=trace_id,
user_id=user_id,
name=name,
input=dict(trace_info.workflow_run_inputs),
output=dict(trace_info.workflow_run_outputs),
metadata=metadata,
session_id=trace_info.conversation_id,
tags=["message", "workflow"],
version=trace_info.workflow_run_version,
)
self.add_trace(langfuse_trace_data=trace_data)
workflow_span_data = LangfuseSpan(
id=trace_info.workflow_run_id,
name=TraceTaskName.WORKFLOW_TRACE,
input=dict(trace_info.workflow_run_inputs),
output=dict(trace_info.workflow_run_outputs),
trace_id=trace_id,
start_time=trace_info.start_time,
end_time=trace_info.end_time,
metadata=metadata,
level=LevelEnum.DEFAULT if not trace_info.error else LevelEnum.ERROR,
status_message=trace_info.error or "",
)
self.add_span(langfuse_span_data=workflow_span_data)
else:
trace_data = LangfuseTrace(
id=trace_id,
user_id=user_id,
name=TraceTaskName.WORKFLOW_TRACE,
input=dict(trace_info.workflow_run_inputs),
output=dict(trace_info.workflow_run_outputs),
metadata=metadata,
session_id=trace_info.conversation_id,
tags=["workflow"],
version=trace_info.workflow_run_version,
)
self.add_trace(langfuse_trace_data=trace_data)
# Fetch all node executions for this workflow run via the repository
session_factory = sessionmaker(bind=db.engine)
# Find the app's creator account
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
service_account = self.get_service_account_with_tenant(app_id)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Get all executions for this workflow run
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
workflow_run_id=trace_info.workflow_run_id
)
for node_execution in workflow_node_executions:
node_execution_id = node_execution.id
tenant_id = trace_info.tenant_id # Use from trace_info instead
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs or {}
outputs = node_execution.outputs or {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
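# The node's finish time is reconstructed from created_at plus the recorded elapsed_time.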
execution_metadata = node_execution.metadata or {}
metadata = {str(k): v for k, v in execution_metadata.items()}
metadata.update(
{
"workflow_run_id": trace_info.workflow_run_id,
"node_execution_id": node_execution_id,
"tenant_id": tenant_id,
"app_id": app_id,
"node_name": node_name,
"node_type": node_type,
"status": status,
}
)
process_data = node_execution.process_data or {}
model_provider = process_data.get("model_provider", None)
model_name = process_data.get("model_name", None)
if model_provider is not None and model_name is not None:
metadata.update(
{
"model_provider": model_provider,
"model_name": model_name,
}
)
# add generation span
if process_data and process_data.get("model_mode") == "chat":
total_token = metadata.get("total_tokens", 0)
prompt_tokens = 0
completion_tokens = 0
try:
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
prompt_tokens = usage_data.get("prompt_tokens", 0)
completion_tokens = usage_data.get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)
# add generation
generation_usage = GenerationUsage(
input=prompt_tokens,
output=completion_tokens,
total=total_token,
unit=UnitEnum.TOKENS,
)
node_generation_data = LangfuseGeneration(
id=node_execution_id,
name=node_name,
trace_id=trace_id,
model=process_data.get("model_name"),
start_time=created_at,
end_time=finished_at,
input=inputs,
output=outputs,
metadata=metadata,
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
status_message=trace_info.error or "",
parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None,
usage=generation_usage,
)
self.add_generation(langfuse_generation_data=node_generation_data)
# add normal span
else:
span_data = LangfuseSpan(
id=node_execution_id,
name=node_name,
input=inputs,
output=outputs,
trace_id=trace_id,
start_time=created_at,
end_time=finished_at,
metadata=metadata,
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
status_message=trace_info.error or "",
parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None,
)
self.add_span(langfuse_span_data=span_data)
def message_trace(self, trace_info: MessageTraceInfo, **kwargs):
# get message file data
file_list = trace_info.file_list
metadata = trace_info.metadata
message_data = trace_info.message_data
if message_data is None:
return
message_id = message_data.id
user_id = message_data.from_account_id
if message_data.from_end_user_id:
end_user_data: EndUser | None = (
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
user_id = end_user_data.session_id
metadata["user_id"] = user_id
trace_id = trace_info.trace_id or message_id
trace_data = LangfuseTrace(
id=trace_id,
user_id=user_id,
name=TraceTaskName.MESSAGE_TRACE,
input={
"message": trace_info.inputs,
"files": file_list,
"message_tokens": trace_info.message_tokens,
"answer_tokens": trace_info.answer_tokens,
"total_tokens": trace_info.total_tokens,
"error": trace_info.error,
"provider_response_latency": message_data.provider_response_latency,
"created_at": trace_info.start_time,
},
output=trace_info.outputs,
metadata=metadata,
session_id=message_data.conversation_id,
tags=["message", str(trace_info.conversation_mode)],
version=None,
release=None,
public=None,
)
self.add_trace(langfuse_trace_data=trace_data)
# add generation
generation_usage = GenerationUsage(
input=trace_info.message_tokens,
output=trace_info.answer_tokens,
total=trace_info.total_tokens,
unit=UnitEnum.TOKENS,
totalCost=message_data.total_price,
)
langfuse_generation_data = LangfuseGeneration(
name="llm",
trace_id=trace_id,
start_time=trace_info.start_time,
end_time=trace_info.end_time,
model=message_data.model_id,
input=trace_info.inputs,
output=message_data.answer,
metadata=metadata,
level=(LevelEnum.DEFAULT if message_data.status != MessageStatus.ERROR else LevelEnum.ERROR),
status_message=message_data.error or "",
usage=generation_usage,
)
self.add_generation(langfuse_generation_data)
def moderation_trace(self, trace_info: ModerationTraceInfo):
if trace_info.message_data is None:
return
span_data = LangfuseSpan(
name=TraceTaskName.MODERATION_TRACE,
input=trace_info.inputs,
output={
"action": trace_info.action,
"flagged": trace_info.flagged,
"preset_response": trace_info.preset_response,
"inputs": trace_info.inputs,
},
trace_id=trace_info.trace_id or trace_info.message_id,
start_time=trace_info.start_time or trace_info.message_data.created_at,
end_time=trace_info.end_time or trace_info.message_data.created_at,
metadata=trace_info.metadata,
)
self.add_span(langfuse_span_data=span_data)
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
message_data = trace_info.message_data
if message_data is None:
return
generation_usage = GenerationUsage(
total=len(str(trace_info.suggested_question)),
input=len(trace_info.inputs) if trace_info.inputs else 0,
output=len(trace_info.suggested_question),
unit=UnitEnum.CHARACTERS,
)
generation_data = LangfuseGeneration(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE,
input=trace_info.inputs,
output=str(trace_info.suggested_question),
trace_id=trace_info.trace_id or trace_info.message_id,
start_time=trace_info.start_time,
end_time=trace_info.end_time,
metadata=trace_info.metadata,
level=(LevelEnum.DEFAULT if message_data.status != MessageStatus.ERROR else LevelEnum.ERROR),
status_message=message_data.error or "",
usage=generation_usage,
)
self.add_generation(langfuse_generation_data=generation_data)
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
return
dataset_retrieval_span_data = LangfuseSpan(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE,
input=trace_info.inputs,
output={"documents": trace_info.documents},
trace_id=trace_info.trace_id or trace_info.message_id,
start_time=trace_info.start_time or trace_info.message_data.created_at,
end_time=trace_info.end_time or trace_info.message_data.updated_at,
metadata=trace_info.metadata,
)
self.add_span(langfuse_span_data=dataset_retrieval_span_data)
def tool_trace(self, trace_info: ToolTraceInfo):
tool_span_data = LangfuseSpan(
name=trace_info.tool_name,
input=trace_info.tool_inputs,
output=trace_info.tool_outputs,
trace_id=trace_info.trace_id or trace_info.message_id,
start_time=trace_info.start_time,
end_time=trace_info.end_time,
metadata=trace_info.metadata,
level=(LevelEnum.DEFAULT if trace_info.error == "" or trace_info.error is None else LevelEnum.ERROR),
status_message=trace_info.error,
)
self.add_span(langfuse_span_data=tool_span_data)
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
name_generation_trace_data = LangfuseTrace(
name=TraceTaskName.GENERATE_NAME_TRACE,
input=trace_info.inputs,
output=trace_info.outputs,
user_id=trace_info.tenant_id,
metadata=trace_info.metadata,
session_id=trace_info.conversation_id,
)
self.add_trace(langfuse_trace_data=name_generation_trace_data)
name_generation_span_data = LangfuseSpan(
name=TraceTaskName.GENERATE_NAME_TRACE,
input=trace_info.inputs,
output=trace_info.outputs,
trace_id=trace_info.conversation_id,
start_time=trace_info.start_time,
end_time=trace_info.end_time,
metadata=trace_info.metadata,
)
self.add_span(langfuse_span_data=name_generation_span_data)
def add_trace(self, langfuse_trace_data: LangfuseTrace | None = None):
format_trace_data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {}
try:
self.langfuse_client.trace(**format_trace_data)
logger.debug("LangFuse Trace created successfully")
except Exception as e:
raise ValueError(f"LangFuse Failed to create trace: {str(e)}")
def add_span(self, langfuse_span_data: LangfuseSpan | None = None):
format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
try:
self.langfuse_client.span(**format_span_data)
logger.debug("LangFuse Span created successfully")
except Exception as e:
raise ValueError(f"LangFuse Failed to create span: {str(e)}")
def update_span(self, span, langfuse_span_data: LangfuseSpan | None = None):
format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
span.end(**format_span_data)
def add_generation(self, langfuse_generation_data: LangfuseGeneration | None = None):
format_generation_data = (
filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
)
try:
self.langfuse_client.generation(**format_generation_data)
logger.debug("LangFuse Generation created successfully")
except Exception as e:
raise ValueError(f"LangFuse Failed to create generation: {str(e)}")
def update_generation(self, generation, langfuse_generation_data: LangfuseGeneration | None = None):
format_generation_data = (
filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
)
generation.end(**format_generation_data)
def api_check(self):
try:
return self.langfuse_client.auth_check()
except Exception as e:
logger.debug("LangFuse API check failed: %s", str(e))
raise ValueError(f"LangFuse API check failed: {str(e)}")
def get_project_key(self):
try:
projects = self.langfuse_client.client.projects.get()
return projects.data[0].id
except Exception as e:
logger.debug("LangFuse get project key failed: %s", str(e))
raise ValueError(f"LangFuse get project key failed: {str(e)}")

View File

@@ -0,0 +1,142 @@
from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Any, Union
from pydantic import BaseModel, Field, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.ops.utils import replace_text_with_content
class LangSmithRunType(StrEnum):
tool = "tool"
chain = "chain"
llm = "llm"
retriever = "retriever"
embedding = "embedding"
prompt = "prompt"
parser = "parser"
class LangSmithTokenUsage(BaseModel):
input_tokens: int | None = None
output_tokens: int | None = None
total_tokens: int | None = None
class LangSmithMultiModel(BaseModel):
file_list: list[str] | None = Field(None, description="List of files")
class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
name: str | None = Field(..., description="Name of the run")
inputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Inputs of the run")
outputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Outputs of the run")
run_type: LangSmithRunType = Field(..., description="Type of the run")
start_time: datetime | str | None = Field(None, description="Start time of the run")
end_time: datetime | str | None = Field(None, description="End time of the run")
extra: dict[str, Any] | None = Field(None, description="Extra information of the run")
error: str | None = Field(None, description="Error message of the run")
serialized: dict[str, Any] | None = Field(None, description="Serialized data of the run")
parent_run_id: str | None = Field(None, description="Parent run ID")
events: list[dict[str, Any]] | None = Field(None, description="Events associated with the run")
tags: list[str] | None = Field(None, description="Tags associated with the run")
trace_id: str | None = Field(None, description="Trace ID associated with the run")
dotted_order: str | None = Field(None, description="Dotted order of the run")
id: str | None = Field(None, description="ID of the run")
session_id: str | None = Field(None, description="Session ID associated with the run")
session_name: str | None = Field(None, description="Session name associated with the run")
reference_example_id: str | None = Field(None, description="Reference example ID associated with the run")
input_attachments: dict[str, Any] | None = Field(None, description="Input attachments of the run")
output_attachments: dict[str, Any] | None = Field(None, description="Output attachments of the run")
@field_validator("inputs", "outputs")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name
values = info.data
if v == {} or v is None:
return v
usage_metadata = {
"input_tokens": values.get("input_tokens", 0),
"output_tokens": values.get("output_tokens", 0),
"total_tokens": values.get("total_tokens", 0),
}
file_list = values.get("file_list", [])
if isinstance(v, str):
if field_name == "inputs":
return {
"messages": {
"role": "user",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
elif field_name == "outputs":
return {
"choices": {
"role": "ai",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
elif isinstance(v, list):
data = {}
if len(v) > 0 and isinstance(v[0], dict):
# rename text to content
v = replace_text_with_content(data=v)
if field_name == "inputs":
data = {
"messages": v,
}
elif field_name == "outputs":
data = {
"choices": {
"role": "ai",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
return data
else:
return {
"choices": {
"role": "ai" if field_name == "outputs" else "user",
"content": str(v),
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
if isinstance(v, dict):
v["usage_metadata"] = usage_metadata
v["file_list"] = file_list
return v
return v
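# Examples of the wrapping above: plain strings become chat-style payloads,
# e.g. inputs="hi"  -> {"messages": {"role": "user", "content": "hi", ...}}
# and  outputs="ok" -> {"choices": {"role": "ai", "content": "ok", ...}}
# (usage_metadata and file_list are merged in from previously validated fields).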
@field_validator("start_time", "end_time")
@classmethod
def format_time(cls, v, info: ValidationInfo):
if v is None:
return v
if not isinstance(v, datetime):
raise ValueError(f"{info.field_name} must be a datetime object")
return v.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
class LangSmithRunUpdateModel(BaseModel):
run_id: str = Field(..., description="ID of the run")
trace_id: str | None = Field(None, description="Trace ID associated with the run")
dotted_order: str | None = Field(None, description="Dotted order of the run")
parent_run_id: str | None = Field(None, description="Parent run ID")
end_time: datetime | str | None = Field(None, description="End time of the run")
error: str | None = Field(None, description="Error message of the run")
inputs: dict[str, Any] | None = Field(None, description="Inputs of the run")
outputs: dict[str, Any] | None = Field(None, description="Outputs of the run")
events: list[dict[str, Any]] | None = Field(None, description="Events associated with the run")
tags: list[str] | None = Field(None, description="Tags associated with the run")
extra: dict[str, Any] | None = Field(None, description="Extra information of the run")
input_attachments: dict[str, Any] | None = Field(None, description="Input attachments of the run")
output_attachments: dict[str, Any] | None = Field(None, description="Output attachments of the run")

View File

@@ -0,0 +1,525 @@
import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import cast
from langsmith import Client
from langsmith.schemas import RunBase
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangSmithConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunModel,
LangSmithRunType,
LangSmithRunUpdateModel,
)
from core.ops.utils import filter_none_values, generate_dotted_order
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
class LangSmithDataTrace(BaseTraceInstance):
def __init__(
self,
langsmith_config: LangSmithConfig,
):
super().__init__(langsmith_config)
self.langsmith_key = langsmith_config.api_key
self.project_name = langsmith_config.project
self.project_id = None
self.langsmith_client = Client(api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint)
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
def trace(self, trace_info: BaseTraceInfo):
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo):
self.moderation_trace(trace_info)
if isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo):
self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
if trace_info.start_time is None:
trace_info.start_time = datetime.now()
message_dotted_order = (
generate_dotted_order(trace_info.message_id, trace_info.start_time) if trace_info.message_id else None
)
workflow_dotted_order = generate_dotted_order(
trace_info.workflow_run_id,
trace_info.workflow_data.created_at,
message_dotted_order,
)
metadata = trace_info.metadata
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
if trace_info.message_id:
message_run = LangSmithRunModel(
id=trace_info.message_id,
name=TraceTaskName.MESSAGE_TRACE,
inputs=dict(trace_info.workflow_run_inputs),
outputs=dict(trace_info.workflow_run_outputs),
run_type=LangSmithRunType.chain,
start_time=trace_info.start_time,
end_time=trace_info.end_time,
extra={
"metadata": metadata,
},
tags=["message", "workflow"],
error=trace_info.error,
trace_id=trace_id,
dotted_order=message_dotted_order,
file_list=[],
serialized=None,
parent_run_id=None,
events=[],
session_id=None,
session_name=None,
reference_example_id=None,
input_attachments={},
output_attachments={},
)
self.add_run(message_run)
langsmith_run = LangSmithRunModel(
file_list=trace_info.file_list,
total_tokens=trace_info.total_tokens,
id=trace_info.workflow_run_id,
name=TraceTaskName.WORKFLOW_TRACE,
inputs=dict(trace_info.workflow_run_inputs),
run_type=LangSmithRunType.tool,
start_time=trace_info.workflow_data.created_at,
end_time=trace_info.workflow_data.finished_at,
outputs=dict(trace_info.workflow_run_outputs),
extra={
"metadata": metadata,
},
error=trace_info.error,
tags=["workflow"],
parent_run_id=trace_info.message_id or None,
trace_id=trace_id,
dotted_order=workflow_dotted_order,
serialized=None,
events=[],
session_id=None,
session_name=None,
reference_example_id=None,
input_attachments={},
output_attachments={},
)
self.add_run(langsmith_run)
# Fetch all node executions for this workflow run via the repository
session_factory = sessionmaker(bind=db.engine)
# Find the app's creator account
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
service_account = self.get_service_account_with_tenant(app_id)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Get all executions for this workflow run
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
workflow_run_id=trace_info.workflow_run_id
)
for node_execution in workflow_node_executions:
node_execution_id = node_execution.id
tenant_id = trace_info.tenant_id # Use from trace_info instead
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs or {}
outputs = node_execution.outputs or {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = node_execution.metadata or {}
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
metadata = {str(key): value for key, value in execution_metadata.items()}
metadata.update(
{
"workflow_run_id": trace_info.workflow_run_id,
"node_execution_id": node_execution_id,
"tenant_id": tenant_id,
"app_id": app_id,
"app_name": node_name,
"node_type": node_type,
"status": status,
}
)
process_data = node_execution.process_data or {}
if process_data and process_data.get("model_mode") == "chat":
run_type = LangSmithRunType.llm
metadata.update(
{
"ls_provider": process_data.get("model_provider", ""),
"ls_model_name": process_data.get("model_name", ""),
}
)
elif node_type == NodeType.KNOWLEDGE_RETRIEVAL:
run_type = LangSmithRunType.retriever
else:
run_type = LangSmithRunType.tool
prompt_tokens = 0
completion_tokens = 0
try:
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
prompt_tokens = usage_data.get("prompt_tokens", 0)
completion_tokens = usage_data.get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)
node_dotted_order = generate_dotted_order(node_execution_id, created_at, workflow_dotted_order)
langsmith_run = LangSmithRunModel(
total_tokens=node_total_tokens,
input_tokens=prompt_tokens,
output_tokens=completion_tokens,
name=node_type,
inputs=inputs,
run_type=run_type,
start_time=created_at,
end_time=finished_at,
outputs=outputs,
file_list=trace_info.file_list,
extra={
"metadata": metadata,
},
parent_run_id=trace_info.workflow_run_id,
tags=["node_execution"],
id=node_execution_id,
trace_id=trace_id,
dotted_order=node_dotted_order,
error="",
serialized=None,
events=[],
session_id=None,
session_name=None,
reference_example_id=None,
input_attachments={},
output_attachments={},
)
self.add_run(langsmith_run)
def message_trace(self, trace_info: MessageTraceInfo):
# get message file data
file_list = cast(list[str], trace_info.file_list) or []
message_file_data: MessageFile | None = trace_info.message_file_data
if message_file_data:
file_list.append(f"{self.file_base_url}/{message_file_data.url}")
metadata = trace_info.metadata
message_data = trace_info.message_data
if message_data is None:
return
message_id = message_data.id
user_id = message_data.from_account_id
metadata["user_id"] = user_id
if message_data.from_end_user_id:
end_user_data: EndUser | None = (
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
end_user_id = end_user_data.session_id
metadata["end_user_id"] = end_user_id
message_run = LangSmithRunModel(
input_tokens=trace_info.message_tokens,
output_tokens=trace_info.answer_tokens,
total_tokens=trace_info.total_tokens,
id=message_id,
name=TraceTaskName.MESSAGE_TRACE,
inputs=trace_info.inputs,
run_type=LangSmithRunType.chain,
start_time=trace_info.start_time,
end_time=trace_info.end_time,
outputs=message_data.answer,
extra={"metadata": metadata},
tags=["message", str(trace_info.conversation_mode)],
error=trace_info.error,
file_list=file_list,
serialized=None,
events=[],
session_id=None,
session_name=None,
reference_example_id=None,
input_attachments={},
output_attachments={},
trace_id=trace_info.trace_id,
dotted_order=None,
parent_run_id=None,
)
self.add_run(message_run)
# create llm run parented to message run
llm_run = LangSmithRunModel(
input_tokens=trace_info.message_tokens,
output_tokens=trace_info.answer_tokens,
total_tokens=trace_info.total_tokens,
name="llm",
inputs=trace_info.inputs,
run_type=LangSmithRunType.llm,
start_time=trace_info.start_time,
end_time=trace_info.end_time,
outputs=message_data.answer,
extra={"metadata": metadata},
parent_run_id=message_id,
tags=["llm", str(trace_info.conversation_mode)],
error=trace_info.error,
file_list=file_list,
serialized=None,
events=[],
session_id=None,
session_name=None,
reference_example_id=None,
input_attachments={},
output_attachments={},
trace_id=trace_info.trace_id,
dotted_order=None,
id=str(uuid.uuid4()),
)
self.add_run(llm_run)
def moderation_trace(self, trace_info: ModerationTraceInfo):
if trace_info.message_data is None:
return
langsmith_run = LangSmithRunModel(
name=TraceTaskName.MODERATION_TRACE,
inputs=trace_info.inputs,
outputs={
"action": trace_info.action,
"flagged": trace_info.flagged,
"preset_response": trace_info.preset_response,
"inputs": trace_info.inputs,
},
run_type=LangSmithRunType.tool,
extra={"metadata": trace_info.metadata},
tags=["moderation"],
parent_run_id=trace_info.message_id,
start_time=trace_info.start_time or trace_info.message_data.created_at,
end_time=trace_info.end_time or trace_info.message_data.updated_at,
id=str(uuid.uuid4()),
serialized=None,
events=[],
session_id=None,
session_name=None,
reference_example_id=None,
input_attachments={},
output_attachments={},
trace_id=trace_info.trace_id,
dotted_order=None,
error="",
file_list=[],
)
self.add_run(langsmith_run)
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
message_data = trace_info.message_data
if message_data is None:
return
suggested_question_run = LangSmithRunModel(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE,
inputs=trace_info.inputs,
outputs=trace_info.suggested_question,
run_type=LangSmithRunType.tool,
extra={"metadata": trace_info.metadata},
tags=["suggested_question"],
parent_run_id=trace_info.message_id,
start_time=trace_info.start_time or message_data.created_at,
end_time=trace_info.end_time or message_data.updated_at,
id=str(uuid.uuid4()),
serialized=None,
events=[],
session_id=None,
session_name=None,
reference_example_id=None,
input_attachments={},
output_attachments={},
trace_id=trace_info.trace_id,
dotted_order=None,
error="",
file_list=[],
)
self.add_run(suggested_question_run)
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
return
dataset_retrieval_run = LangSmithRunModel(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE,
inputs=trace_info.inputs,
outputs={"documents": trace_info.documents},
run_type=LangSmithRunType.retriever,
extra={"metadata": trace_info.metadata},
tags=["dataset_retrieval"],
parent_run_id=trace_info.message_id,
start_time=trace_info.start_time or trace_info.message_data.created_at,
end_time=trace_info.end_time or trace_info.message_data.updated_at,
id=str(uuid.uuid4()),
serialized=None,
events=[],
session_id=None,
session_name=None,
reference_example_id=None,
input_attachments={},
output_attachments={},
trace_id=trace_info.trace_id,
dotted_order=None,
error="",
file_list=[],
)
self.add_run(dataset_retrieval_run)
def tool_trace(self, trace_info: ToolTraceInfo):
tool_run = LangSmithRunModel(
name=trace_info.tool_name,
inputs=trace_info.tool_inputs,
outputs=trace_info.tool_outputs,
run_type=LangSmithRunType.tool,
extra={
"metadata": trace_info.metadata,
},
tags=["tool", trace_info.tool_name],
parent_run_id=trace_info.message_id,
start_time=trace_info.start_time,
end_time=trace_info.end_time,
file_list=[cast(str, trace_info.file_url)],
id=str(uuid.uuid4()),
serialized=None,
events=[],
session_id=None,
session_name=None,
reference_example_id=None,
input_attachments={},
output_attachments={},
trace_id=trace_info.trace_id,
dotted_order=None,
error=trace_info.error or "",
)
self.add_run(tool_run)
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
name_run = LangSmithRunModel(
name=TraceTaskName.GENERATE_NAME_TRACE,
inputs=trace_info.inputs,
outputs=trace_info.outputs,
run_type=LangSmithRunType.tool,
extra={"metadata": trace_info.metadata},
tags=["generate_name"],
start_time=trace_info.start_time or datetime.now(),
end_time=trace_info.end_time or datetime.now(),
id=str(uuid.uuid4()),
serialized=None,
events=[],
session_id=None,
session_name=None,
reference_example_id=None,
input_attachments={},
output_attachments={},
trace_id=trace_info.trace_id,
dotted_order=None,
error="",
file_list=[],
parent_run_id=None,
)
self.add_run(name_run)
def add_run(self, run_data: LangSmithRunModel):
data = run_data.model_dump()
if self.project_id:
data["session_id"] = self.project_id
elif self.project_name:
data["session_name"] = self.project_name
data = filter_none_values(data)
try:
self.langsmith_client.create_run(**data)
logger.debug("LangSmith Run created successfully.")
except Exception as e:
raise ValueError(f"LangSmith Failed to create run: {str(e)}")
def update_run(self, update_run_data: LangSmithRunUpdateModel):
data = update_run_data.model_dump()
data = filter_none_values(data)
try:
self.langsmith_client.update_run(**data)
logger.debug("LangSmith Run updated successfully.")
except Exception as e:
raise ValueError(f"LangSmith Failed to update run: {str(e)}")
def api_check(self):
try:
random_project_name = f"test_project_{datetime.now().strftime('%Y%m%d%H%M%S')}"
self.langsmith_client.create_project(project_name=random_project_name)
self.langsmith_client.delete_project(project_name=random_project_name)
return True
except Exception as e:
logger.debug("LangSmith API check failed: %s", str(e))
raise ValueError(f"LangSmith API check failed: {str(e)}")
def get_project_url(self):
try:
run_data = RunBase(
id=uuid.uuid4(),
name="tool",
inputs={"input": "test"},
outputs={"output": "test"},
run_type=LangSmithRunType.tool,
start_time=datetime.now(),
)
project_url = self.langsmith_client.get_run_url(
run=run_data, project_id=self.project_id, project_name=self.project_name
)
return project_url.split("/r/")[0]
except Exception as e:
logger.debug("LangSmith get run url failed: %s", str(e))
raise ValueError(f"LangSmith get run url failed: {str(e)}")

View File

@@ -0,0 +1,549 @@
import json
import logging
import os
from datetime import datetime, timedelta
from typing import Any, cast
import mlflow
from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType
from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey
from mlflow.tracing.fluent import start_span_no_context, update_current_trace
from mlflow.tracing.provider import detach_span_from_context, set_span_in_context
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
TraceTaskName,
WorkflowTraceInfo,
)
from core.workflow.enums import NodeType
from extensions.ext_database import db
from models import EndUser
from models.workflow import WorkflowNodeExecutionModel
logger = logging.getLogger(__name__)
def datetime_to_nanoseconds(dt: datetime | None) -> int | None:
"""Convert datetime to nanosecond timestamp for MLflow API"""
if dt is None:
return None
return int(dt.timestamp() * 1_000_000_000)
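# Illustrative check of the conversion: datetime(2024, 1, 1, tzinfo=timezone.utc)
# maps to 1_704_067_200_000_000_000 ns, and None passes through as None, so the
# result can be handed directly to MLflow's start_time_ns/end_time_ns parameters.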
class MLflowDataTrace(BaseTraceInstance):
def __init__(self, config: MLflowConfig | DatabricksConfig):
super().__init__(config)
if isinstance(config, DatabricksConfig):
self._setup_databricks(config)
else:
self._setup_mlflow(config)
# Enable async logging to minimize performance overhead
os.environ["MLFLOW_ENABLE_ASYNC_TRACE_LOGGING"] = "true"
def _setup_databricks(self, config: DatabricksConfig):
"""Setup connection to Databricks-managed MLflow instances"""
os.environ["DATABRICKS_HOST"] = config.host
if config.client_id and config.client_secret:
# OAuth: https://docs.databricks.com/aws/en/dev-tools/auth/oauth-m2m?language=Environment
os.environ["DATABRICKS_CLIENT_ID"] = config.client_id
os.environ["DATABRICKS_CLIENT_SECRET"] = config.client_secret
elif config.personal_access_token:
# PAT: https://docs.databricks.com/aws/en/dev-tools/auth/pat
os.environ["DATABRICKS_TOKEN"] = config.personal_access_token
else:
raise ValueError(
"Either Databricks token (PAT) or client id and secret (OAuth) must be provided"
"See https://docs.databricks.com/aws/en/dev-tools/auth/#what-authorization-option-should-i-choose "
"for more information about the authorization options."
)
mlflow.set_tracking_uri("databricks")
mlflow.set_experiment(experiment_id=config.experiment_id)
# Remove trailing slash from host
config.host = config.host.rstrip("/")
self._project_url = f"{config.host}/ml/experiments/{config.experiment_id}/traces"
def _setup_mlflow(self, config: MLflowConfig):
"""Setup connection to MLflow instances"""
mlflow.set_tracking_uri(config.tracking_uri)
mlflow.set_experiment(experiment_id=config.experiment_id)
# Simple auth if provided
if config.username and config.password:
os.environ["MLFLOW_TRACKING_USERNAME"] = config.username
os.environ["MLFLOW_TRACKING_PASSWORD"] = config.password
self._project_url = f"{config.tracking_uri}/#/experiments/{config.experiment_id}/traces"
def trace(self, trace_info: BaseTraceInfo):
"""Simple dispatch to trace methods"""
try:
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
elif isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
elif isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
elif isinstance(trace_info, ModerationTraceInfo):
self.moderation_trace(trace_info)
elif isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
elif isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
elif isinstance(trace_info, GenerateNameTraceInfo):
self.generate_name_trace(trace_info)
except Exception:
logger.exception("[MLflow] Trace error")
raise
def workflow_trace(self, trace_info: WorkflowTraceInfo):
"""Create workflow span as root, with node spans as children"""
# Fields prefixed with "sys." are added by Dify and duplicate trace_info.metadata
raw_inputs = trace_info.workflow_run_inputs or {}
workflow_inputs = {k: v for k, v in raw_inputs.items() if not k.startswith("sys.")}
# Special inputs propagated by system
if trace_info.query:
workflow_inputs["query"] = trace_info.query
workflow_span = start_span_no_context(
name=TraceTaskName.WORKFLOW_TRACE.value,
span_type=SpanType.CHAIN,
inputs=workflow_inputs,
attributes=trace_info.metadata,
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
)
# Set reserved fields in trace-level metadata
trace_metadata = {}
if user_id := trace_info.metadata.get("user_id"):
trace_metadata[TraceMetadataKey.TRACE_USER] = user_id
if session_id := trace_info.conversation_id:
trace_metadata[TraceMetadataKey.TRACE_SESSION] = session_id
self._set_trace_metadata(workflow_span, trace_metadata)
try:
# Create child spans for workflow nodes
for node in self._get_workflow_nodes(trace_info.workflow_run_id):
inputs = None
attributes = {
"node_id": node.id,
"node_type": node.node_type,
"status": node.status,
"tenant_id": node.tenant_id,
"app_id": node.app_id,
"app_name": node.title,
}
if node.node_type in (NodeType.LLM, NodeType.QUESTION_CLASSIFIER):
inputs, llm_attributes = self._parse_llm_inputs_and_attributes(node)
attributes.update(llm_attributes)
elif node.node_type == NodeType.HTTP_REQUEST:
inputs = node.process_data # contains request URL
if not inputs:
inputs = json.loads(node.inputs) if node.inputs else {}
node_span = start_span_no_context(
name=node.title,
span_type=self._get_node_span_type(node.node_type),
parent_span=workflow_span,
inputs=inputs,
attributes=attributes,
start_time_ns=datetime_to_nanoseconds(node.created_at),
)
# Handle node errors
if node.status != "succeeded":
node_span.set_status(SpanStatusCode.ERROR)
node_span.add_event(
SpanEvent( # type: ignore[abstract]
name="exception",
attributes={
"exception.message": f"Node failed with status: {node.status}",
"exception.type": "Error",
"exception.stacktrace": f"Node failed with status: {node.status}",
},
)
)
# End node span
finished_at = node.created_at + timedelta(seconds=node.elapsed_time)
outputs = json.loads(node.outputs) if node.outputs else {}
if node.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
outputs = self._parse_knowledge_retrieval_outputs(outputs)
elif node.node_type == NodeType.LLM:
outputs = outputs.get("text", outputs)
node_span.end(
outputs=outputs,
end_time_ns=datetime_to_nanoseconds(finished_at),
)
# Handle workflow-level errors
if trace_info.error:
workflow_span.set_status(SpanStatusCode.ERROR)
workflow_span.add_event(
SpanEvent( # type: ignore[abstract]
name="exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
)
finally:
workflow_span.end(
outputs=trace_info.workflow_run_outputs,
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
)
def _parse_llm_inputs_and_attributes(self, node: WorkflowNodeExecutionModel) -> tuple[Any, dict]:
"""Parse LLM inputs and attributes from LLM workflow node"""
if node.process_data is None:
return {}, {}
try:
data = json.loads(node.process_data)
except (json.JSONDecodeError, TypeError):
return {}, {}
inputs = self._parse_prompts(data.get("prompts"))
attributes = {
"model_name": data.get("model_name"),
"model_provider": data.get("model_provider"),
"finish_reason": data.get("finish_reason"),
}
if hasattr(SpanAttributeKey, "MESSAGE_FORMAT"):
attributes[SpanAttributeKey.MESSAGE_FORMAT] = "dify"
if usage := data.get("usage"):
# Set reserved token usage attributes
attributes[SpanAttributeKey.CHAT_USAGE] = {
TokenUsageKey.INPUT_TOKENS: usage.get("prompt_tokens", 0),
TokenUsageKey.OUTPUT_TOKENS: usage.get("completion_tokens", 0),
TokenUsageKey.TOTAL_TOKENS: usage.get("total_tokens", 0),
}
# Store the raw usage data as well, since it includes extra fields such as price
attributes["usage"] = usage
return inputs, attributes
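# Illustration with an assumed process_data payload:
#   {"model_name": "gpt-4o", "model_provider": "openai",
#    "usage": {"prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15}}
# yields attributes whose SpanAttributeKey.CHAT_USAGE entry is
#   {INPUT_TOKENS: 12, OUTPUT_TOKENS: 3, TOTAL_TOKENS: 15},
# alongside the raw "usage" dict, which keeps provider extras such as price.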
def _parse_knowledge_retrieval_outputs(self, outputs: dict):
"""Parse KR outputs and attributes from KR workflow node"""
retrieved = outputs.get("result", [])
if not retrieved or not isinstance(retrieved, list):
return outputs
documents = []
for item in retrieved:
documents.append(Document(page_content=item.get("content", ""), metadata=item.get("metadata", {})))
return documents
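# Illustration: {"result": [{"content": "doc text", "metadata": {"score": 0.8}}]}
# becomes [Document(page_content="doc text", metadata={"score": 0.8})], so the
# retriever span shows structured documents instead of a raw dict.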
def message_trace(self, trace_info: MessageTraceInfo):
"""Create span for CHATBOT message processing"""
if not trace_info.message_data:
return
file_list = cast(list[str], trace_info.file_list) or []
if message_file_data := trace_info.message_file_data:
base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
file_list.append(f"{base_url}/{message_file_data.url}")
span = start_span_no_context(
name=TraceTaskName.MESSAGE_TRACE.value,
span_type=SpanType.LLM,
inputs=self._parse_prompts(trace_info.inputs), # type: ignore[arg-type]
attributes={
"message_id": trace_info.message_id, # type: ignore[dict-item]
"model_provider": trace_info.message_data.model_provider,
"model_id": trace_info.message_data.model_id,
"conversation_mode": trace_info.conversation_mode,
"file_list": file_list, # type: ignore[dict-item]
"total_price": trace_info.message_data.total_price,
**trace_info.metadata,
},
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
)
if hasattr(SpanAttributeKey, "MESSAGE_FORMAT"):
span.set_attribute(SpanAttributeKey.MESSAGE_FORMAT, "dify")
# Set token usage
span.set_attribute(
SpanAttributeKey.CHAT_USAGE,
{
TokenUsageKey.INPUT_TOKENS: trace_info.message_tokens or 0,
TokenUsageKey.OUTPUT_TOKENS: trace_info.answer_tokens or 0,
TokenUsageKey.TOTAL_TOKENS: trace_info.total_tokens or 0,
},
)
# Set reserved fields in trace-level metadata
trace_metadata = {}
if user_id := self._get_message_user_id(trace_info.metadata):
trace_metadata[TraceMetadataKey.TRACE_USER] = user_id
if session_id := trace_info.metadata.get("conversation_id"):
trace_metadata[TraceMetadataKey.TRACE_SESSION] = session_id
self._set_trace_metadata(span, trace_metadata)
if trace_info.error:
span.set_status(SpanStatusCode.ERROR)
span.add_event(
SpanEvent( # type: ignore[abstract]
name="error",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
)
span.end(
outputs=trace_info.message_data.answer,
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
)
def _get_message_user_id(self, metadata: dict) -> str | None:
if (end_user_id := metadata.get("from_end_user_id")) and (
end_user_data := db.session.query(EndUser).where(EndUser.id == end_user_id).first()
):
return end_user_data.session_id
return metadata.get("from_account_id") # type: ignore[return-value]
def tool_trace(self, trace_info: ToolTraceInfo):
span = start_span_no_context(
name=trace_info.tool_name,
span_type=SpanType.TOOL,
inputs=trace_info.tool_inputs, # type: ignore[arg-type]
attributes={
"message_id": trace_info.message_id, # type: ignore[dict-item]
"metadata": trace_info.metadata, # type: ignore[dict-item]
"tool_config": trace_info.tool_config, # type: ignore[dict-item]
"tool_parameters": trace_info.tool_parameters, # type: ignore[dict-item]
},
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
)
# Handle tool errors
if trace_info.error:
span.set_status(SpanStatusCode.ERROR)
span.add_event(
SpanEvent( # type: ignore[abstract]
name="error",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
)
span.end(
outputs=trace_info.tool_outputs,
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
)
def moderation_trace(self, trace_info: ModerationTraceInfo):
if trace_info.message_data is None:
return
start_time = trace_info.start_time or trace_info.message_data.created_at
span = start_span_no_context(
name=TraceTaskName.MODERATION_TRACE.value,
span_type=SpanType.TOOL,
inputs=trace_info.inputs or {},
attributes={
"message_id": trace_info.message_id, # type: ignore[dict-item]
"metadata": trace_info.metadata, # type: ignore[dict-item]
},
start_time_ns=datetime_to_nanoseconds(start_time),
)
span.end(
outputs={
"action": trace_info.action,
"flagged": trace_info.flagged,
"preset_response": trace_info.preset_response,
},
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
)
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
return
span = start_span_no_context(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
span_type=SpanType.RETRIEVER,
inputs=trace_info.inputs,
attributes={
"message_id": trace_info.message_id, # type: ignore[dict-item]
"metadata": trace_info.metadata, # type: ignore[dict-item]
},
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
)
span.end(outputs={"documents": trace_info.documents}, end_time_ns=datetime_to_nanoseconds(trace_info.end_time))
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
if trace_info.message_data is None:
return
start_time = trace_info.start_time or trace_info.message_data.created_at
end_time = trace_info.end_time or trace_info.message_data.updated_at
span = start_span_no_context(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
span_type=SpanType.TOOL,
inputs=trace_info.inputs,
attributes={
"message_id": trace_info.message_id, # type: ignore[dict-item]
"model_provider": trace_info.model_provider, # type: ignore[dict-item]
"model_id": trace_info.model_id, # type: ignore[dict-item]
"total_tokens": trace_info.total_tokens or 0, # type: ignore[dict-item]
},
start_time_ns=datetime_to_nanoseconds(start_time),
)
if trace_info.error:
span.set_status(SpanStatusCode.ERROR)
span.add_event(
SpanEvent( # type: ignore[abstract]
name="error",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
)
span.end(outputs=trace_info.suggested_question, end_time_ns=datetime_to_nanoseconds(end_time))
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
span = start_span_no_context(
name=TraceTaskName.GENERATE_NAME_TRACE.value,
span_type=SpanType.CHAIN,
inputs=trace_info.inputs,
attributes={"message_id": trace_info.message_id}, # type: ignore[dict-item]
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
)
span.end(outputs=trace_info.outputs, end_time_ns=datetime_to_nanoseconds(trace_info.end_time))
def _get_workflow_nodes(self, workflow_run_id: str):
"""Helper method to get workflow nodes"""
workflow_nodes = (
db.session.query(
WorkflowNodeExecutionModel.id,
WorkflowNodeExecutionModel.tenant_id,
WorkflowNodeExecutionModel.app_id,
WorkflowNodeExecutionModel.title,
WorkflowNodeExecutionModel.node_type,
WorkflowNodeExecutionModel.status,
WorkflowNodeExecutionModel.inputs,
WorkflowNodeExecutionModel.outputs,
WorkflowNodeExecutionModel.created_at,
WorkflowNodeExecutionModel.elapsed_time,
WorkflowNodeExecutionModel.process_data,
WorkflowNodeExecutionModel.execution_metadata,
)
.filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.order_by(WorkflowNodeExecutionModel.created_at)
.all()
)
return workflow_nodes
def _get_node_span_type(self, node_type: str) -> str:
"""Map Dify node types to MLflow span types"""
node_type_mapping = {
NodeType.LLM: SpanType.LLM,
NodeType.QUESTION_CLASSIFIER: SpanType.LLM,
NodeType.KNOWLEDGE_RETRIEVAL: SpanType.RETRIEVER,
NodeType.TOOL: SpanType.TOOL,
NodeType.CODE: SpanType.TOOL,
NodeType.HTTP_REQUEST: SpanType.TOOL,
NodeType.AGENT: SpanType.AGENT,
}
return node_type_mapping.get(node_type, SpanType.CHAIN)  # type: ignore[arg-type]
def _set_trace_metadata(self, span: Span, metadata: dict):
token = None
try:
# NB: Set span in context such that we can use update_current_trace() API
token = set_span_in_context(span)
update_current_trace(metadata=metadata)
finally:
if token:
detach_span_from_context(token)
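# Illustration: spans created via start_span_no_context() are not attached to
# the active OTel context, so update_current_trace() would not find them;
# temporarily attaching the span makes the metadata land on the right trace.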
def _parse_prompts(self, prompts):
"""Postprocess prompts format to be standard chat messages"""
if isinstance(prompts, str):
return prompts
elif isinstance(prompts, dict):
return self._parse_single_message(prompts)
elif isinstance(prompts, list):
messages = [self._parse_single_message(item) for item in prompts]
messages = self._resolve_tool_call_ids(messages)
return messages
return prompts # Fallback to original format
def _parse_single_message(self, item: dict):
"""Postprocess single message format to be standard chat message"""
role = item.get("role", "user")
msg = {"role": role, "content": item.get("text", "")}
if (
(tool_calls := item.get("tool_calls"))
# Tool message does not contain tool calls normally
and role != "tool"
):
msg["tool_calls"] = tool_calls
if files := item.get("files"):
msg["files"] = files
return msg
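# Illustration: {"role": "assistant", "text": "hi", "tool_calls": [...]} is
# normalized to {"role": "assistant", "content": "hi", "tool_calls": [...]},
# while a "tool"-role item keeps only role/content (plus optional files).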
def _resolve_tool_call_ids(self, messages: list[dict]):
"""
The tool call message from Dify does not contain tool call ids, which is not
ideal for debugging. This method resolves the tool call ids by matching the
tool call name and parameters with the tool instruction messages.
"""
tool_call_ids = []
for msg in messages:
if tool_calls := msg.get("tool_calls"):
tool_call_ids = [t["id"] for t in tool_calls]
if msg["role"] == "tool":
# Get the tool call id in the order of the tool call messages
# assuming Dify runs tools sequentially
if tool_call_ids:
msg["tool_call_id"] = tool_call_ids.pop(0)
return messages
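# Illustration: for an assistant message with tool call ids ["c1", "c2"]
# followed by two "tool" messages, the first tool message receives
# tool_call_id="c1" and the second "c2"; the mapping relies on the
# sequential-execution assumption and breaks if tools run out of order.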
def api_check(self):
"""Simple connection test"""
try:
mlflow.search_experiments(max_results=1)
return True
except Exception as e:
raise ValueError(f"MLflow connection failed: {str(e)}")
def get_project_url(self):
return self._project_url

View File

View File

@@ -0,0 +1,462 @@
import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import cast
from opik import Opik, Trace
from opik.id_helpers import uuid4_to_uuid7
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import OpikConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
TraceTaskName,
WorkflowTraceInfo,
)
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
def wrap_dict(key_name, data):
"""Make sure that the input data is a dict"""
if not isinstance(data, dict):
return {key_name: data}
return data
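# Illustration: wrap_dict("input", "hello") -> {"input": "hello"}, while a
# dict such as {"query": "hi"} passes through unchanged.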
def wrap_metadata(metadata, **kwargs):
"""Add common metatada to all Traces and Spans"""
metadata["created_from"] = "dify"
metadata.update(kwargs)
return metadata
def prepare_opik_uuid(user_datetime: datetime | None, user_uuid: str | None):
"""Opik needs UUIDv7 while Dify uses UUIDv4 for identifier of most
messages and objects. The type-hints of BaseTraceInfo indicates that
objects start_time and message_id could be null which means we cannot map
it to a UUIDv7. Given that we have no way to identify that object
uniquely, generate a new random one UUIDv7 in that case.
"""
if user_datetime is None:
user_datetime = datetime.now()
if user_uuid is None:
user_uuid = str(uuid.uuid4())
return uuid4_to_uuid7(user_datetime, user_uuid)
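# Illustration: with a known start_time and UUIDv4 the mapping is deterministic,
# so re-sending the same trace yields the same UUIDv7; with None inputs a fresh
# random UUIDv7 is generated and the object cannot be correlated across retries.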
class OpikDataTrace(BaseTraceInstance):
def __init__(
self,
opik_config: OpikConfig,
):
super().__init__(opik_config)
self.opik_client = Opik(
project_name=opik_config.project,
workspace=opik_config.workspace,
host=opik_config.url,
api_key=opik_config.api_key,
)
self.project = opik_config.project
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
def trace(self, trace_info: BaseTraceInfo):
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo):
self.moderation_trace(trace_info)
if isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo):
self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo):
dify_trace_id = trace_info.trace_id or trace_info.workflow_run_id
opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
workflow_metadata = wrap_metadata(
trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id
)
root_span_id = None
if trace_info.message_id:
dify_trace_id = trace_info.trace_id or trace_info.message_id
opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
trace_data = {
"id": opik_trace_id,
"name": TraceTaskName.MESSAGE_TRACE,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": workflow_metadata,
"input": wrap_dict("input", trace_info.workflow_run_inputs),
"output": wrap_dict("output", trace_info.workflow_run_outputs),
"thread_id": trace_info.conversation_id,
"tags": ["message", "workflow"],
"project_name": self.project,
}
self.add_trace(trace_data)
root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id)
span_data = {
"id": root_span_id,
"parent_span_id": None,
"trace_id": opik_trace_id,
"name": TraceTaskName.WORKFLOW_TRACE,
"input": wrap_dict("input", trace_info.workflow_run_inputs),
"output": wrap_dict("output", trace_info.workflow_run_outputs),
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": workflow_metadata,
"tags": ["workflow"],
"project_name": self.project,
}
self.add_span(span_data)
else:
trace_data = {
"id": opik_trace_id,
"name": TraceTaskName.MESSAGE_TRACE,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": workflow_metadata,
"input": wrap_dict("input", trace_info.workflow_run_inputs),
"output": wrap_dict("output", trace_info.workflow_run_outputs),
"thread_id": trace_info.conversation_id,
"tags": ["workflow"],
"project_name": self.project,
}
self.add_trace(trace_data)
# Fetch all node executions for this workflow run via the repository
session_factory = sessionmaker(bind=db.engine)
# Find the app's creator account
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
service_account = self.get_service_account_with_tenant(app_id)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Get all executions for this workflow run
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
workflow_run_id=trace_info.workflow_run_id
)
for node_execution in workflow_node_executions:
node_execution_id = node_execution.id
tenant_id = trace_info.tenant_id # Use from trace_info instead
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs or {}
outputs = node_execution.outputs or {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = node_execution.metadata or {}
metadata = {str(k): v for k, v in execution_metadata.items()}
metadata.update(
{
"workflow_run_id": trace_info.workflow_run_id,
"node_execution_id": node_execution_id,
"tenant_id": tenant_id,
"app_id": app_id,
"app_name": node_name,
"node_type": node_type,
"status": status,
}
)
process_data = node_execution.process_data or {}
provider = None
model = None
total_tokens = 0
completion_tokens = 0
prompt_tokens = 0
if process_data and process_data.get("model_mode") == "chat":
run_type = "llm"
provider = process_data.get("model_provider", None)
model = process_data.get("model_name", "")
metadata.update(
{
"ls_provider": provider,
"ls_model_name": model,
}
)
try:
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
total_tokens = usage_data.get("total_tokens", 0)
prompt_tokens = usage_data.get("prompt_tokens", 0)
completion_tokens = usage_data.get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)
else:
run_type = "tool"
parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
if not total_tokens:
total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
span_data = {
"trace_id": opik_trace_id,
"id": prepare_opik_uuid(created_at, node_execution_id),
"parent_span_id": prepare_opik_uuid(trace_info.start_time, parent_span_id),
"name": node_name,
"type": run_type,
"start_time": created_at,
"end_time": finished_at,
"metadata": wrap_metadata(metadata),
"input": wrap_dict("input", inputs),
"output": wrap_dict("output", outputs),
"tags": ["node_execution"],
"project_name": self.project,
"usage": {
"total_tokens": total_tokens,
"completion_tokens": completion_tokens,
"prompt_tokens": prompt_tokens,
},
"model": model,
"provider": provider,
}
self.add_span(span_data)
def message_trace(self, trace_info: MessageTraceInfo):
# get message file data
file_list = cast(list[str], trace_info.file_list) or []
message_file_data: MessageFile | None = trace_info.message_file_data
if message_file_data is not None:
file_list.append(f"{self.file_base_url}/{message_file_data.url}")
message_data = trace_info.message_data
if message_data is None:
return
metadata = trace_info.metadata
dify_trace_id = trace_info.trace_id or trace_info.message_id
user_id = message_data.from_account_id
metadata["user_id"] = user_id
metadata["file_list"] = file_list
if message_data.from_end_user_id:
end_user_data: EndUser | None = (
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
end_user_id = end_user_data.session_id
metadata["end_user_id"] = end_user_id
trace_data = {
"id": prepare_opik_uuid(trace_info.start_time, dify_trace_id),
"name": TraceTaskName.MESSAGE_TRACE,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(metadata),
"input": trace_info.inputs,
"output": message_data.answer,
"thread_id": message_data.conversation_id,
"tags": ["message", str(trace_info.conversation_mode)],
"project_name": self.project,
}
trace = self.add_trace(trace_data)
span_data = {
"trace_id": trace.id,
"name": "llm",
"type": "llm",
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(metadata),
"input": {"input": trace_info.inputs},
"output": {"output": message_data.answer},
"tags": ["llm", str(trace_info.conversation_mode)],
"usage": {
"completion_tokens": trace_info.answer_tokens,
"prompt_tokens": trace_info.message_tokens,
"total_tokens": trace_info.total_tokens,
},
"project_name": self.project,
}
self.add_span(span_data)
def moderation_trace(self, trace_info: ModerationTraceInfo):
if trace_info.message_data is None:
return
start_time = trace_info.start_time or trace_info.message_data.created_at
span_data = {
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
"name": TraceTaskName.MODERATION_TRACE,
"type": "tool",
"start_time": start_time,
"end_time": trace_info.end_time or trace_info.message_data.updated_at,
"metadata": wrap_metadata(trace_info.metadata),
"input": wrap_dict("input", trace_info.inputs),
"output": {
"action": trace_info.action,
"flagged": trace_info.flagged,
"preset_response": trace_info.preset_response,
"inputs": trace_info.inputs,
},
"tags": ["moderation"],
}
self.add_span(span_data)
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
message_data = trace_info.message_data
if message_data is None:
return
start_time = trace_info.start_time or message_data.created_at
span_data = {
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
"name": TraceTaskName.SUGGESTED_QUESTION_TRACE,
"type": "tool",
"start_time": start_time,
"end_time": trace_info.end_time or message_data.updated_at,
"metadata": wrap_metadata(trace_info.metadata),
"input": wrap_dict("input", trace_info.inputs),
"output": wrap_dict("output", trace_info.suggested_question),
"tags": ["suggested_question"],
}
self.add_span(span_data)
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
return
start_time = trace_info.start_time or trace_info.message_data.created_at
span_data = {
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
"name": TraceTaskName.DATASET_RETRIEVAL_TRACE,
"type": "tool",
"start_time": start_time,
"end_time": trace_info.end_time or trace_info.message_data.updated_at,
"metadata": wrap_metadata(trace_info.metadata),
"input": wrap_dict("input", trace_info.inputs),
"output": {"documents": trace_info.documents},
"tags": ["dataset_retrieval"],
}
self.add_span(span_data)
def tool_trace(self, trace_info: ToolTraceInfo):
span_data = {
"trace_id": prepare_opik_uuid(trace_info.start_time, trace_info.trace_id or trace_info.message_id),
"name": trace_info.tool_name,
"type": "tool",
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(trace_info.metadata),
"input": wrap_dict("input", trace_info.tool_inputs),
"output": wrap_dict("output", trace_info.tool_outputs),
"tags": ["tool", trace_info.tool_name],
}
self.add_span(span_data)
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
trace_data = {
"id": prepare_opik_uuid(trace_info.start_time, trace_info.trace_id or trace_info.message_id),
"name": TraceTaskName.GENERATE_NAME_TRACE,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(trace_info.metadata),
"input": trace_info.inputs,
"output": trace_info.outputs,
"thread_id": trace_info.conversation_id,
"tags": ["generate_name"],
"project_name": self.project,
}
trace = self.add_trace(trace_data)
span_data = {
"trace_id": trace.id,
"name": TraceTaskName.GENERATE_NAME_TRACE,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(trace_info.metadata),
"input": wrap_dict("input", trace_info.inputs),
"output": wrap_dict("output", trace_info.outputs),
"tags": ["generate_name"],
}
self.add_span(span_data)
def add_trace(self, opik_trace_data: dict) -> Trace:
try:
trace = self.opik_client.trace(**opik_trace_data)
logger.debug("Opik Trace created successfully")
return trace
except Exception as e:
raise ValueError(f"Opik Failed to create trace: {str(e)}")
def add_span(self, opik_span_data: dict):
try:
self.opik_client.span(**opik_span_data)
logger.debug("Opik Span created successfully")
except Exception as e:
raise ValueError(f"Opik Failed to create span: {str(e)}")
def api_check(self):
try:
self.opik_client.auth_check()
return True
except Exception as e:
logger.info("Opik API check failed: %s", str(e), exc_info=True)
raise ValueError(f"Opik API check failed: {str(e)}")
def get_project_url(self):
try:
return self.opik_client.get_project_url(project_name=self.project)
except Exception as e:
logger.info("Opik get run url failed: %s", str(e), exc_info=True)
raise ValueError(f"Opik get run url failed: {str(e)}")

View File

@@ -0,0 +1,997 @@
import collections
import json
import logging
import os
import queue
import threading
import time
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Optional, Union
from uuid import UUID, uuid4
from cachetools import LRUCache
from flask import current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import (
OPS_FILE_PATH,
TracingProviderEnum,
)
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
TaskData,
ToolTraceInfo,
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.utils import get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
from models.workflow import WorkflowAppLog
from repositories.factory import DifyAPIRepositoryFactory
from tasks.ops_trace_task import process_trace_tasks
if TYPE_CHECKING:
from core.workflow.entities import WorkflowExecution
logger = logging.getLogger(__name__)
class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
def __getitem__(self, provider: str) -> dict[str, Any]:
match provider:
case TracingProviderEnum.LANGFUSE:
from core.ops.entities.config_entity import LangfuseConfig
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
return {
"config_class": LangfuseConfig,
"secret_keys": ["public_key", "secret_key"],
"other_keys": ["host", "project_key"],
"trace_instance": LangFuseDataTrace,
}
case TracingProviderEnum.LANGSMITH:
from core.ops.entities.config_entity import LangSmithConfig
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
return {
"config_class": LangSmithConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "endpoint"],
"trace_instance": LangSmithDataTrace,
}
case TracingProviderEnum.OPIK:
from core.ops.entities.config_entity import OpikConfig
from core.ops.opik_trace.opik_trace import OpikDataTrace
return {
"config_class": OpikConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "url", "workspace"],
"trace_instance": OpikDataTrace,
}
case TracingProviderEnum.WEAVE:
from core.ops.entities.config_entity import WeaveConfig
from core.ops.weave_trace.weave_trace import WeaveDataTrace
return {
"config_class": WeaveConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "entity", "endpoint", "host"],
"trace_instance": WeaveDataTrace,
}
case TracingProviderEnum.ARIZE:
from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
from core.ops.entities.config_entity import ArizeConfig
return {
"config_class": ArizeConfig,
"secret_keys": ["api_key", "space_id"],
"other_keys": ["project", "endpoint"],
"trace_instance": ArizePhoenixDataTrace,
}
case TracingProviderEnum.PHOENIX:
from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
from core.ops.entities.config_entity import PhoenixConfig
return {
"config_class": PhoenixConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "endpoint"],
"trace_instance": ArizePhoenixDataTrace,
}
case TracingProviderEnum.ALIYUN:
from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
from core.ops.entities.config_entity import AliyunConfig
return {
"config_class": AliyunConfig,
"secret_keys": ["license_key"],
"other_keys": ["endpoint", "app_name"],
"trace_instance": AliyunDataTrace,
}
case TracingProviderEnum.MLFLOW:
from core.ops.entities.config_entity import MLflowConfig
from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace
return {
"config_class": MLflowConfig,
"secret_keys": ["password"],
"other_keys": ["tracking_uri", "experiment_id", "username"],
"trace_instance": MLflowDataTrace,
}
case TracingProviderEnum.DATABRICKS:
from core.ops.entities.config_entity import DatabricksConfig
from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace
return {
"config_class": DatabricksConfig,
"secret_keys": ["personal_access_token", "client_secret"],
"other_keys": ["host", "client_id", "experiment_id"],
"trace_instance": MLflowDataTrace,
}
case TracingProviderEnum.TENCENT:
from core.ops.entities.config_entity import TencentConfig
from core.ops.tencent_trace.tencent_trace import TencentDataTrace
return {
"config_class": TencentConfig,
"secret_keys": ["token"],
"other_keys": ["endpoint", "service_name"],
"trace_instance": TencentDataTrace,
}
case _:
raise KeyError(f"Unsupported tracing provider: {provider}")
provider_config_map = OpsTraceProviderConfigMap()
class OpsTraceManager:
ops_trace_instances_cache: LRUCache = LRUCache(maxsize=128)
decrypted_configs_cache: LRUCache = LRUCache(maxsize=128)
_decryption_cache_lock = threading.RLock()
@classmethod
def encrypt_tracing_config(
cls, tenant_id: str, tracing_provider: str, tracing_config: dict, current_trace_config=None
):
"""
Encrypt tracing config.
:param tenant_id: tenant id
:param tracing_provider: tracing provider
:param tracing_config: tracing config dictionary to be encrypted
:param current_trace_config: current tracing configuration for keeping existing values
:return: encrypted tracing configuration
"""
# Get the configuration class and the keys that require encryption
config_class, secret_keys, other_keys = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["secret_keys"],
provider_config_map[tracing_provider]["other_keys"],
)
new_config: dict[str, Any] = {}
# Encrypt necessary keys
for key in secret_keys:
if key in tracing_config:
if "*" in tracing_config[key]:
# If the value is obfuscated (contains '*'), retain the original value from the current config
if current_trace_config:
new_config[key] = current_trace_config.get(key, tracing_config[key])
else:
new_config[key] = tracing_config[key]
else:
# Otherwise, encrypt the key
new_config[key] = encrypt_token(tenant_id, tracing_config[key])
for key in other_keys:
new_config[key] = tracing_config.get(key, "")
# Create a new instance of the config class with the new configuration
encrypted_config = config_class(**new_config)
return encrypted_config.model_dump()
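# Illustration with hypothetical values: saving {"api_key": "sk-live-123"}
# stores encrypt_token(tenant_id, "sk-live-123"), while re-saving an
# obfuscated form such as {"api_key": "sk-****123"} keeps the ciphertext
# already present in current_trace_config instead of encrypting the mask.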
@classmethod
def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict):
"""
Decrypt tracing config
:param tenant_id: tenant id
:param tracing_provider: tracing provider
:param tracing_config: tracing config
:return:
"""
config_json = json.dumps(tracing_config, sort_keys=True)
decrypted_config_key = (
tenant_id,
tracing_provider,
config_json,
)
# First check without lock for performance
cached_config = cls.decrypted_configs_cache.get(decrypted_config_key)
if cached_config is not None:
return dict(cached_config)
with cls._decryption_cache_lock:
# Second check (double-checked locking) to prevent race conditions
cached_config = cls.decrypted_configs_cache.get(decrypted_config_key)
if cached_config is not None:
return dict(cached_config)
config_class, secret_keys, other_keys = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["secret_keys"],
provider_config_map[tracing_provider]["other_keys"],
)
new_config: dict[str, Any] = {}
keys_to_decrypt = [key for key in secret_keys if key in tracing_config]
if keys_to_decrypt:
decrypted_values = batch_decrypt_token(tenant_id, [tracing_config[key] for key in keys_to_decrypt])
new_config.update(zip(keys_to_decrypt, decrypted_values))
for key in other_keys:
new_config[key] = tracing_config.get(key, "")
decrypted_config = config_class(**new_config).model_dump()
cls.decrypted_configs_cache[decrypted_config_key] = decrypted_config
return dict(decrypted_config)
@classmethod
def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
"""
Decrypt tracing config
:param tracing_provider: tracing provider
:param decrypt_tracing_config: tracing config
:return:
"""
config_class, secret_keys, other_keys = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["secret_keys"],
provider_config_map[tracing_provider]["other_keys"],
)
new_config: dict[str, Any] = {}
for key in secret_keys:
if key in decrypt_tracing_config:
new_config[key] = obfuscated_token(decrypt_tracing_config[key])
for key in other_keys:
new_config[key] = decrypt_tracing_config.get(key, "")
return config_class(**new_config).model_dump()
@classmethod
def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str):
"""
Get decrypted tracing config
:param app_id: app id
:param tracing_provider: tracing provider
:return:
"""
trace_config_data: TraceAppConfig | None = (
db.session.query(TraceAppConfig)
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
)
if not trace_config_data:
return None
# decrypt_token
stmt = select(App).where(App.id == app_id)
app = db.session.scalar(stmt)
if not app:
raise ValueError("App not found")
tenant_id = app.tenant_id
if trace_config_data.tracing_config is None:
raise ValueError("Tracing config cannot be None.")
decrypt_tracing_config = cls.decrypt_tracing_config(
tenant_id, tracing_provider, trace_config_data.tracing_config
)
return decrypt_tracing_config
@classmethod
def get_ops_trace_instance(
cls,
app_id: Union[UUID, str] | None = None,
):
"""
Get ops trace through model config
:param app_id: app_id
:return:
"""
if isinstance(app_id, UUID):
app_id = str(app_id)
if app_id is None:
return None
app: App | None = db.session.query(App).where(App.id == app_id).first()
if app is None:
return None
app_ops_trace_config = json.loads(app.tracing) if app.tracing else None
if app_ops_trace_config is None:
return None
if not app_ops_trace_config.get("enabled"):
return None
tracing_provider = app_ops_trace_config.get("tracing_provider")
if tracing_provider is None:
return None
try:
provider_config_map[tracing_provider]
except KeyError:
return None
# decrypt_token
decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider)
if not decrypt_trace_config:
return None
trace_instance, config_class = (
provider_config_map[tracing_provider]["trace_instance"],
provider_config_map[tracing_provider]["config_class"],
)
decrypt_trace_config_key = json.dumps(decrypt_trace_config, sort_keys=True)
tracing_instance = cls.ops_trace_instances_cache.get(decrypt_trace_config_key)
if tracing_instance is None:
# create a new tracing_instance and update the cache if it is absent
tracing_instance = trace_instance(config_class(**decrypt_trace_config))
cls.ops_trace_instances_cache[decrypt_trace_config_key] = tracing_instance
logger.info("new tracing_instance for app_id: %s", app_id)
return tracing_instance
@classmethod
def get_app_config_through_message_id(cls, message_id: str):
app_model_config = None
message_stmt = select(Message).where(Message.id == message_id)
message_data = db.session.scalar(message_stmt)
if not message_data:
return None
conversation_id = message_data.conversation_id
conversation_stmt = select(Conversation).where(Conversation.id == conversation_id)
conversation_data = db.session.scalar(conversation_stmt)
if not conversation_data:
return None
if conversation_data.app_model_config_id:
config_stmt = select(AppModelConfig).where(AppModelConfig.id == conversation_data.app_model_config_id)
app_model_config = db.session.scalar(config_stmt)
elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
app_model_config = conversation_data.override_model_configs
return app_model_config
@classmethod
def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: str):
"""
Update app tracing config
:param app_id: app id
:param enabled: enabled
:param tracing_provider: tracing provider
:return:
"""
# auth check
try:
if enabled or tracing_provider is not None:
provider_config_map[tracing_provider]
except KeyError:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
app_config: App | None = db.session.query(App).where(App.id == app_id).first()
if not app_config:
raise ValueError("App not found")
app_config.tracing = json.dumps(
{
"enabled": enabled,
"tracing_provider": tracing_provider,
}
)
db.session.commit()
@classmethod
def get_app_tracing_config(cls, app_id: str):
"""
Get app tracing config
:param app_id: app id
:return:
"""
app: App | None = db.session.query(App).where(App.id == app_id).first()
if not app:
raise ValueError("App not found")
if not app.tracing:
return {"enabled": False, "tracing_provider": None}
app_trace_config = json.loads(app.tracing)
return app_trace_config
@staticmethod
def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str):
"""
Check trace config is effective
:param tracing_config: tracing config
:param tracing_provider: tracing provider
:return:
"""
config_type, trace_instance = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["trace_instance"],
)
tracing_config = config_type(**tracing_config)
return trace_instance(tracing_config).api_check()
@staticmethod
def get_trace_config_project_key(tracing_config: dict, tracing_provider: str):
"""
Get the project key for the given trace config
:param tracing_config: tracing config
:param tracing_provider: tracing provider
:return:
"""
config_type, trace_instance = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["trace_instance"],
)
tracing_config = config_type(**tracing_config)
return trace_instance(tracing_config).get_project_key()
@staticmethod
def get_trace_config_project_url(tracing_config: dict, tracing_provider: str):
"""
Get the project URL for the given trace config
:param tracing_config: tracing config
:param tracing_provider: tracing provider
:return:
"""
config_type, trace_instance = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["trace_instance"],
)
tracing_config = config_type(**tracing_config)
return trace_instance(tracing_config).get_project_url()
class TraceTask:
_workflow_run_repo = None
_repo_lock = threading.Lock()
@classmethod
def _get_workflow_run_repo(cls):
if cls._workflow_run_repo is None:
with cls._repo_lock:
if cls._workflow_run_repo is None:
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
return cls._workflow_run_repo
def __init__(
self,
trace_type: Any,
message_id: str | None = None,
workflow_execution: Optional["WorkflowExecution"] = None,
conversation_id: str | None = None,
user_id: str | None = None,
timer: Any | None = None,
**kwargs,
):
self.trace_type = trace_type
self.message_id = message_id
self.workflow_run_id = workflow_execution.id_ if workflow_execution else None
self.conversation_id = conversation_id
self.user_id = user_id
self.timer = timer
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
self.app_id = None
self.trace_id = None
self.kwargs = kwargs
external_trace_id = kwargs.get("external_trace_id")
if external_trace_id:
self.trace_id = external_trace_id
def execute(self):
return self.preprocess()
def preprocess(self):
preprocess_map = {
TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs),
TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
),
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
message_id=self.message_id, timer=self.timer, **self.kwargs
),
TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace(
message_id=self.message_id, timer=self.timer, **self.kwargs
),
TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace(
message_id=self.message_id, timer=self.timer, **self.kwargs
),
TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(
message_id=self.message_id, timer=self.timer, **self.kwargs
),
TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
),
}
return preprocess_map.get(self.trace_type, lambda: None)()
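# Minimal usage sketch (assumes the timer is a dict like {"start": ..., "end": ...}):
#   task = TraceTask(TraceTaskName.TOOL_TRACE, message_id=message_id, timer=timer,
#                    tool_name="search", tool_inputs={...}, tool_outputs={...})
#   trace_info = task.execute()  # returns a ToolTraceInfo ready for the trace queue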
# process methods for different trace types
def conversation_trace(self, **kwargs):
return kwargs
def workflow_trace(
self,
*,
workflow_run_id: str | None,
conversation_id: str | None,
user_id: str | None,
):
if not workflow_run_id:
return {}
workflow_run_repo = self._get_workflow_run_repo()
workflow_run = workflow_run_repo.get_workflow_run_by_id_without_tenant(run_id=workflow_run_id)
if not workflow_run:
raise ValueError("Workflow run not found")
workflow_id = workflow_run.workflow_id
tenant_id = workflow_run.tenant_id
workflow_run_id = workflow_run.id
workflow_run_elapsed_time = workflow_run.elapsed_time
workflow_run_status = workflow_run.status
workflow_run_inputs = workflow_run.inputs_dict
workflow_run_outputs = workflow_run.outputs_dict
workflow_run_version = workflow_run.version
error = workflow_run.error or ""
total_tokens = workflow_run.total_tokens
file_list = workflow_run_inputs.get("sys.file") or []
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
with Session(db.engine) as session:
# get workflow_app_log_id
workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
WorkflowAppLog.tenant_id == tenant_id,
WorkflowAppLog.app_id == workflow_run.app_id,
WorkflowAppLog.workflow_run_id == workflow_run.id,
)
workflow_app_log_id = session.scalar(workflow_app_log_data_stmt)
# get message_id
message_id = None
if conversation_id:
message_data_stmt = select(Message.id).where(
Message.conversation_id == conversation_id,
Message.workflow_run_id == workflow_run_id,
)
message_id = session.scalar(message_data_stmt)
metadata = {
"workflow_id": workflow_id,
"conversation_id": conversation_id,
"workflow_run_id": workflow_run_id,
"tenant_id": tenant_id,
"elapsed_time": workflow_run_elapsed_time,
"status": workflow_run_status,
"version": workflow_run_version,
"total_tokens": total_tokens,
"file_list": file_list,
"triggered_from": workflow_run.triggered_from,
"user_id": user_id,
"app_id": workflow_run.app_id,
}
workflow_trace_info = WorkflowTraceInfo(
trace_id=self.trace_id,
workflow_data=workflow_run.to_dict(),
conversation_id=conversation_id,
workflow_id=workflow_id,
tenant_id=tenant_id,
workflow_run_id=workflow_run_id,
workflow_run_elapsed_time=workflow_run_elapsed_time,
workflow_run_status=workflow_run_status,
workflow_run_inputs=workflow_run_inputs,
workflow_run_outputs=workflow_run_outputs,
workflow_run_version=workflow_run_version,
error=error,
total_tokens=total_tokens,
file_list=file_list,
query=query,
metadata=metadata,
workflow_app_log_id=workflow_app_log_id,
message_id=message_id,
start_time=workflow_run.created_at,
end_time=workflow_run.finished_at,
)
return workflow_trace_info
def message_trace(self, message_id: str | None):
if not message_id:
return {}
message_data = get_message_data(message_id)
if not message_data:
return {}
conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id)
conversation_mode = db.session.scalars(conversation_mode_stmt).all()
if not conversation_mode:
return {}
conversation_mode = conversation_mode[0]
created_at = message_data.created_at
inputs = message_data.message
# get message file data
message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
file_list = []
if message_file_data and message_file_data.url is not None:
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
file_list.append(file_url)
streaming_metrics = self._extract_streaming_metrics(message_data)
metadata = {
"conversation_id": message_data.conversation_id,
"ls_provider": message_data.model_provider,
"ls_model_name": message_data.model_id,
"status": message_data.status,
"from_end_user_id": message_data.from_end_user_id,
"from_account_id": message_data.from_account_id,
"agent_based": message_data.agent_based,
"workflow_run_id": message_data.workflow_run_id,
"from_source": message_data.from_source,
"message_id": message_id,
}
message_tokens = message_data.message_tokens
message_trace_info = MessageTraceInfo(
trace_id=self.trace_id,
message_id=message_id,
message_data=message_data.to_dict(),
conversation_model=conversation_mode,
message_tokens=message_tokens,
answer_tokens=message_data.answer_tokens,
total_tokens=message_tokens + message_data.answer_tokens,
error=message_data.error or "",
inputs=inputs,
outputs=message_data.answer,
file_list=file_list,
start_time=created_at,
end_time=created_at + timedelta(seconds=message_data.provider_response_latency),
metadata=metadata,
message_file_data=message_file_data,
conversation_mode=conversation_mode,
gen_ai_server_time_to_first_token=streaming_metrics.get("gen_ai_server_time_to_first_token"),
llm_streaming_time_to_generate=streaming_metrics.get("llm_streaming_time_to_generate"),
is_streaming_request=streaming_metrics.get("is_streaming_request", False),
)
return message_trace_info
def moderation_trace(self, message_id, timer, **kwargs):
moderation_result = kwargs.get("moderation_result")
if not moderation_result:
return {}
inputs = kwargs.get("inputs")
message_data = get_message_data(message_id)
if not message_data:
return {}
metadata = {
"message_id": message_id,
"action": moderation_result.action,
"preset_response": moderation_result.preset_response,
"query": moderation_result.query,
}
# get workflow_app_log_id
workflow_app_log_id = None
if message_data.workflow_run_id:
workflow_app_log_data = (
db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
)
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
moderation_trace_info = ModerationTraceInfo(
trace_id=self.trace_id,
message_id=workflow_app_log_id or message_id,
inputs=inputs,
message_data=message_data.to_dict(),
flagged=moderation_result.flagged,
action=moderation_result.action,
preset_response=moderation_result.preset_response,
query=moderation_result.query,
start_time=timer.get("start"),
end_time=timer.get("end"),
metadata=metadata,
)
return moderation_trace_info
def suggested_question_trace(self, message_id, timer, **kwargs):
suggested_question = kwargs.get("suggested_question", [])
message_data = get_message_data(message_id)
if not message_data:
return {}
metadata = {
"message_id": message_id,
"ls_provider": message_data.model_provider,
"ls_model_name": message_data.model_id,
"status": message_data.status,
"from_end_user_id": message_data.from_end_user_id,
"from_account_id": message_data.from_account_id,
"agent_based": message_data.agent_based,
"workflow_run_id": message_data.workflow_run_id,
"from_source": message_data.from_source,
}
# get workflow_app_log_id
workflow_app_log_id = None
if message_data.workflow_run_id:
workflow_app_log_data = (
db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
)
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
suggested_question_trace_info = SuggestedQuestionTraceInfo(
trace_id=self.trace_id,
message_id=workflow_app_log_id or message_id,
message_data=message_data.to_dict(),
inputs=message_data.message,
outputs=message_data.answer,
start_time=timer.get("start"),
end_time=timer.get("end"),
metadata=metadata,
total_tokens=message_data.message_tokens + message_data.answer_tokens,
status=message_data.status,
error=message_data.error,
from_account_id=message_data.from_account_id,
agent_based=message_data.agent_based,
from_source=message_data.from_source,
model_provider=message_data.model_provider,
model_id=message_data.model_id,
suggested_question=suggested_question,
level=message_data.status,
status_message=message_data.error,
)
return suggested_question_trace_info
def dataset_retrieval_trace(self, message_id, timer, **kwargs):
documents = kwargs.get("documents")
message_data = get_message_data(message_id)
if not message_data:
return {}
metadata = {
"message_id": message_id,
"ls_provider": message_data.model_provider,
"ls_model_name": message_data.model_id,
"status": message_data.status,
"from_end_user_id": message_data.from_end_user_id,
"from_account_id": message_data.from_account_id,
"agent_based": message_data.agent_based,
"workflow_run_id": message_data.workflow_run_id,
"from_source": message_data.from_source,
}
dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
trace_id=self.trace_id,
message_id=message_id,
inputs=message_data.query or message_data.inputs,
documents=[doc.model_dump() for doc in documents] if documents else [],
start_time=timer.get("start"),
end_time=timer.get("end"),
metadata=metadata,
message_data=message_data.to_dict(),
error=kwargs.get("error"),
)
return dataset_retrieval_trace_info
def tool_trace(self, message_id, timer, **kwargs):
tool_name = kwargs.get("tool_name", "")
tool_inputs = kwargs.get("tool_inputs", {})
tool_outputs = kwargs.get("tool_outputs", {})
message_data = get_message_data(message_id)
if not message_data:
return {}
tool_config = {}
time_cost = 0
error = None
tool_parameters = {}
created_time = message_data.created_at
end_time = message_data.updated_at
agent_thoughts = message_data.agent_thoughts
for agent_thought in agent_thoughts:
if tool_name in agent_thought.tools:
created_time = agent_thought.created_at
tool_meta_data = agent_thought.tool_meta.get(tool_name, {})
tool_config = tool_meta_data.get("tool_config", {})
time_cost = tool_meta_data.get("time_cost", 0)
end_time = created_time + timedelta(seconds=time_cost)
error = tool_meta_data.get("error", "")
tool_parameters = tool_meta_data.get("tool_parameters", {})
metadata = {
"message_id": message_id,
"tool_name": tool_name,
"tool_inputs": tool_inputs,
"tool_outputs": tool_outputs,
"tool_config": tool_config,
"time_cost": time_cost,
"error": error,
"tool_parameters": tool_parameters,
}
file_url = ""
message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
if message_file_data:
message_file_id = message_file_data.id
file_type = message_file_data.type
created_by_role = message_file_data.created_by_role
created_user_id = message_file_data.created_by
file_url = f"{self.file_base_url}/{message_file_data.url}"
metadata.update(
{
"message_file_id": message_file_id,
"created_by_role": created_by_role,
"created_user_id": created_user_id,
"type": type,
}
)
tool_trace_info = ToolTraceInfo(
trace_id=self.trace_id,
message_id=message_id,
message_data=message_data.to_dict(),
tool_name=tool_name,
start_time=timer.get("start") if timer else created_time,
end_time=timer.get("end") if timer else end_time,
tool_inputs=tool_inputs,
tool_outputs=tool_outputs,
metadata=metadata,
message_file_data=message_file_data,
error=error,
inputs=message_data.message,
outputs=message_data.answer,
tool_config=tool_config,
time_cost=time_cost,
tool_parameters=tool_parameters,
file_url=file_url,
)
return tool_trace_info
def generate_name_trace(self, conversation_id, timer, **kwargs):
generate_conversation_name = kwargs.get("generate_conversation_name")
inputs = kwargs.get("inputs")
tenant_id = kwargs.get("tenant_id")
if not tenant_id:
return {}
start_time = timer.get("start")
end_time = timer.get("end")
metadata = {
"conversation_id": conversation_id,
"tenant_id": tenant_id,
}
generate_name_trace_info = GenerateNameTraceInfo(
trace_id=self.trace_id,
conversation_id=conversation_id,
inputs=inputs,
outputs=generate_conversation_name,
start_time=start_time,
end_time=end_time,
metadata=metadata,
tenant_id=tenant_id,
)
return generate_name_trace_info
def _extract_streaming_metrics(self, message_data) -> dict:
if not message_data.message_metadata:
return {}
try:
metadata = json.loads(message_data.message_metadata)
usage = metadata.get("usage", {})
time_to_first_token = usage.get("time_to_first_token")
time_to_generate = usage.get("time_to_generate")
return {
"gen_ai_server_time_to_first_token": time_to_first_token,
"llm_streaming_time_to_generate": time_to_generate,
"is_streaming_request": time_to_first_token is not None,
}
except (json.JSONDecodeError, AttributeError):
return {}
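# Shape note (an assumption inferred from the keys read above, not a schema
# guarantee): given message_metadata of
#     {"usage": {"time_to_first_token": 0.42, "time_to_generate": 3.1}}
# this helper returns
#     {"gen_ai_server_time_to_first_token": 0.42,
#      "llm_streaming_time_to_generate": 3.1,
#      "is_streaming_request": True}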
trace_manager_timer: threading.Timer | None = None
trace_manager_queue: queue.Queue = queue.Queue()
trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", "5"))
trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", "100"))
class TraceQueueManager:
def __init__(self, app_id=None, user_id=None):
global trace_manager_timer
self.app_id = app_id
self.user_id = user_id
self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
self.flask_app = current_app._get_current_object() # type: ignore
if trace_manager_timer is None:
self.start_timer()
def add_trace_task(self, trace_task: TraceTask):
global trace_manager_timer, trace_manager_queue
try:
if self.trace_instance:
trace_task.app_id = self.app_id
trace_manager_queue.put(trace_task)
except Exception:
logger.exception("Error adding trace task, trace_type %s", trace_task.trace_type)
finally:
self.start_timer()
def collect_tasks(self):
global trace_manager_queue
tasks: list[TraceTask] = []
while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty():
task = trace_manager_queue.get_nowait()
tasks.append(task)
trace_manager_queue.task_done()
return tasks
def run(self):
try:
tasks = self.collect_tasks()
if tasks:
self.send_to_celery(tasks)
except Exception:
logger.exception("Error processing trace tasks")
def start_timer(self):
global trace_manager_timer
if trace_manager_timer is None or not trace_manager_timer.is_alive():
trace_manager_timer = threading.Timer(trace_manager_interval, self.run)
trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}"
trace_manager_timer.daemon = False
trace_manager_timer.start()
def send_to_celery(self, tasks: list[TraceTask]):
with self.flask_app.app_context():
for task in tasks:
if task.app_id is None:
continue
file_id = uuid4().hex
trace_info = task.execute()
task_data = TaskData(
app_id=task.app_id,
trace_info_type=type(trace_info).__name__,
trace_info=trace_info.model_dump() if trace_info else None,
)
file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
file_info = {
"file_id": file_id,
"app_id": task.app_id,
}
process_trace_tasks.delay(file_info) # type: ignore
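# Hedged usage sketch: callers construct a manager per request and enqueue
# TraceTask objects; the module-level timer then drains the queue to Celery in
# batches of trace_manager_batch_size. The app.id / current_user.id names below
# are illustrative assumptions; TraceTask is defined earlier in this file.
#
#     trace_manager = TraceQueueManager(app_id=app.id, user_id=current_user.id)
#     trace_manager.add_trace_task(trace_task)  # trace_task: TraceTask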

View File

@@ -0,0 +1,565 @@
"""
Tencent APM Trace Client - handles network operations, metrics, and API communication
"""
from __future__ import annotations
import importlib
import json
import logging
import os
import socket
from typing import TYPE_CHECKING
from urllib.parse import urlparse
try:
from importlib.metadata import version
except ImportError:
from importlib_metadata import version # type: ignore[import-not-found]
if TYPE_CHECKING:
from opentelemetry.metrics import Meter
from opentelemetry.metrics._internal.instrument import Histogram
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import MetricReader
from opentelemetry import trace as trace_api
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.trace import SpanKind
from opentelemetry.util.types import AttributeValue
from configs import dify_config
from .entities.semconv import (
GEN_AI_SERVER_TIME_TO_FIRST_TOKEN,
GEN_AI_STREAMING_TIME_TO_GENERATE,
GEN_AI_TOKEN_USAGE,
GEN_AI_TRACE_DURATION,
LLM_OPERATION_DURATION,
)
from .entities.tencent_trace_entity import SpanData
logger = logging.getLogger(__name__)
def _get_opentelemetry_sdk_version() -> str:
"""Get OpenTelemetry SDK version dynamically."""
try:
return version("opentelemetry-sdk")
except Exception:
logger.debug("Failed to get opentelemetry-sdk version, using default")
return "1.27.0" # fallback version
class TencentTraceClient:
"""Tencent APM trace client using OpenTelemetry OTLP exporter"""
def __init__(
self,
service_name: str,
endpoint: str,
token: str,
max_queue_size: int = 1000,
schedule_delay_sec: int = 5,
max_export_batch_size: int = 50,
metrics_export_interval_sec: int = 10,
):
self.endpoint = endpoint
self.token = token
self.service_name = service_name
self.metrics_export_interval_sec = metrics_export_interval_sec
self.resource = Resource(
attributes={
ResourceAttributes.SERVICE_NAME: service_name,
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
ResourceAttributes.HOST_NAME: socket.gethostname(),
ResourceAttributes.TELEMETRY_SDK_LANGUAGE: "python",
ResourceAttributes.TELEMETRY_SDK_NAME: "opentelemetry",
ResourceAttributes.TELEMETRY_SDK_VERSION: _get_opentelemetry_sdk_version(),
}
)
# Prepare gRPC endpoint/metadata
grpc_endpoint, insecure, _, _ = self._resolve_grpc_target(endpoint)
headers = (("authorization", f"Bearer {token}"),)
self.exporter = OTLPSpanExporter(
endpoint=grpc_endpoint,
headers=headers,
insecure=insecure,
timeout=30,
)
self.tracer_provider = TracerProvider(resource=self.resource)
self.span_processor = BatchSpanProcessor(
span_exporter=self.exporter,
max_queue_size=max_queue_size,
schedule_delay_millis=schedule_delay_sec * 1000,
max_export_batch_size=max_export_batch_size,
)
self.tracer_provider.add_span_processor(self.span_processor)
# Use the Dify API version as the tracer version
self.tracer = self.tracer_provider.get_tracer("dify-sdk", dify_config.project.version)
# Store span contexts for parent-child relationships
self.span_contexts: dict[int, trace_api.SpanContext] = {}
self.meter: Meter | None = None
self.meter_provider: MeterProvider | None = None
self.hist_llm_duration: Histogram | None = None
self.hist_token_usage: Histogram | None = None
self.hist_time_to_first_token: Histogram | None = None
self.hist_time_to_generate: Histogram | None = None
self.hist_trace_duration: Histogram | None = None
self.metric_reader: MetricReader | None = None
# Metrics exporter and instruments
try:
from opentelemetry.sdk.metrics import Histogram, MeterProvider
from opentelemetry.sdk.metrics.export import AggregationTemporality, PeriodicExportingMetricReader
protocol = os.getenv("OTEL_EXPORTER_OTLP_PROTOCOL", "").strip().lower()
use_http_protobuf = protocol in {"http/protobuf", "http-protobuf"}
use_http_json = protocol in {"http/json", "http-json"}
# Tencent APM works best with delta aggregation temporality
preferred_temporality: dict[type, AggregationTemporality] = {Histogram: AggregationTemporality.DELTA}
def _create_metric_exporter(exporter_cls, **kwargs):
"""Create metric exporter with preferred_temporality support"""
try:
return exporter_cls(**kwargs, preferred_temporality=preferred_temporality)
except Exception:
return exporter_cls(**kwargs)
metric_reader = None
if use_http_json:
exporter_cls = None
for mod_path in (
"opentelemetry.exporter.otlp.http.json.metric_exporter",
"opentelemetry.exporter.otlp.json.metric_exporter",
):
try:
mod = importlib.import_module(mod_path)
exporter_cls = getattr(mod, "OTLPMetricExporter", None)
if exporter_cls:
break
except Exception:
continue
if exporter_cls is not None:
metric_exporter = _create_metric_exporter(
exporter_cls,
endpoint=endpoint,
headers={"authorization": f"Bearer {token}"},
)
else:
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
OTLPMetricExporter as HttpMetricExporter,
)
metric_exporter = _create_metric_exporter(
HttpMetricExporter,
endpoint=endpoint,
headers={"authorization": f"Bearer {token}"},
)
metric_reader = PeriodicExportingMetricReader(
metric_exporter, export_interval_millis=self.metrics_export_interval_sec * 1000
)
elif use_http_protobuf:
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
OTLPMetricExporter as HttpMetricExporter,
)
metric_exporter = _create_metric_exporter(
HttpMetricExporter,
endpoint=endpoint,
headers={"authorization": f"Bearer {token}"},
)
metric_reader = PeriodicExportingMetricReader(
metric_exporter, export_interval_millis=self.metrics_export_interval_sec * 1000
)
else:
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import (
OTLPMetricExporter as GrpcMetricExporter,
)
m_grpc_endpoint, m_insecure, _, _ = self._resolve_grpc_target(endpoint)
metric_exporter = _create_metric_exporter(
GrpcMetricExporter,
endpoint=m_grpc_endpoint,
headers={"authorization": f"Bearer {token}"},
insecure=m_insecure,
)
metric_reader = PeriodicExportingMetricReader(
metric_exporter, export_interval_millis=self.metrics_export_interval_sec * 1000
)
if metric_reader is not None:
# Use instance-level MeterProvider instead of global to support config changes
# without worker restart. Each TencentTraceClient manages its own MeterProvider.
provider = MeterProvider(resource=self.resource, metric_readers=[metric_reader])
self.meter_provider = provider
self.meter = provider.get_meter("dify-sdk", dify_config.project.version)
# LLM operation duration histogram
self.hist_llm_duration = self.meter.create_histogram(
name=LLM_OPERATION_DURATION,
unit="s",
description="LLM operation duration (seconds)",
)
# Token usage histogram with exponential buckets
self.hist_token_usage = self.meter.create_histogram(
name=GEN_AI_TOKEN_USAGE,
unit="token",
description="Number of tokens used in prompt and completions",
)
# Time to first token histogram
self.hist_time_to_first_token = self.meter.create_histogram(
name=GEN_AI_SERVER_TIME_TO_FIRST_TOKEN,
unit="s",
description="Time to first token for streaming LLM responses (seconds)",
)
# Time to generate histogram
self.hist_time_to_generate = self.meter.create_histogram(
name=GEN_AI_STREAMING_TIME_TO_GENERATE,
unit="s",
description="Total time to generate streaming LLM responses (seconds)",
)
# Trace duration histogram
self.hist_trace_duration = self.meter.create_histogram(
name=GEN_AI_TRACE_DURATION,
unit="s",
description="End-to-end GenAI trace duration (seconds)",
)
self.metric_reader = metric_reader
else:
self.meter = None
self.meter_provider = None
self.hist_llm_duration = None
self.hist_token_usage = None
self.hist_time_to_first_token = None
self.hist_time_to_generate = None
self.hist_trace_duration = None
self.metric_reader = None
except Exception:
logger.exception("[Tencent APM] Metrics initialization failed; metrics disabled")
self.meter = None
self.meter_provider = None
self.hist_llm_duration = None
self.hist_token_usage = None
self.hist_time_to_first_token = None
self.hist_time_to_generate = None
self.hist_trace_duration = None
self.metric_reader = None
def add_span(self, span_data: SpanData) -> None:
"""Create and export span using OpenTelemetry Tracer API"""
try:
self._create_and_export_span(span_data)
logger.debug("[Tencent APM] Created span: %s", span_data.name)
except Exception:
logger.exception("[Tencent APM] Failed to create span: %s", span_data.name)
# Metrics recording API
def record_llm_duration(self, latency_seconds: float, attributes: dict[str, str] | None = None) -> None:
"""Record LLM operation duration histogram in seconds."""
try:
if not hasattr(self, "hist_llm_duration") or self.hist_llm_duration is None:
return
attrs: dict[str, str] = {}
if attributes:
for k, v in attributes.items():
attrs[k] = str(v) if not isinstance(v, (str, int, float, bool)) else v # type: ignore[assignment]
logger.info(
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
LLM_OPERATION_DURATION,
latency_seconds,
json.dumps(attrs, ensure_ascii=False),
)
self.hist_llm_duration.record(latency_seconds, attrs) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record LLM duration", exc_info=True)
def record_token_usage(
self,
token_count: int,
token_type: str,
operation_name: str,
request_model: str,
response_model: str,
server_address: str,
provider: str,
) -> None:
"""Record token usage histogram.
Args:
token_count: Number of tokens used
token_type: "input" or "output"
operation_name: Operation name (e.g., "chat")
request_model: Model used in request
response_model: Model used in response
server_address: Server address
provider: Model provider name
"""
try:
if not hasattr(self, "hist_token_usage") or self.hist_token_usage is None:
return
attributes = {
"gen_ai.operation.name": operation_name,
"gen_ai.request.model": request_model,
"gen_ai.response.model": response_model,
"gen_ai.system": provider,
"gen_ai.token.type": token_type,
"server.address": server_address,
}
logger.info(
"[Tencent Metrics] Metric: %s | Value: %d | Attributes: %s",
GEN_AI_TOKEN_USAGE,
token_count,
json.dumps(attributes, ensure_ascii=False),
)
self.hist_token_usage.record(token_count, attributes) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record token usage", exc_info=True)
def record_time_to_first_token(
self, ttft_seconds: float, provider: str, model: str, operation_name: str = "chat"
) -> None:
"""Record time to first token histogram.
Args:
ttft_seconds: Time to first token in seconds
provider: Model provider name
model: Model name
operation_name: Operation name (default: "chat")
"""
try:
if not hasattr(self, "hist_time_to_first_token") or self.hist_time_to_first_token is None:
return
attributes = {
"gen_ai.operation.name": operation_name,
"gen_ai.system": provider,
"gen_ai.request.model": model,
"gen_ai.response.model": model,
"stream": "true",
}
logger.info(
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
GEN_AI_SERVER_TIME_TO_FIRST_TOKEN,
ttft_seconds,
json.dumps(attributes, ensure_ascii=False),
)
self.hist_time_to_first_token.record(ttft_seconds, attributes) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record time to first token", exc_info=True)
def record_time_to_generate(
self, ttg_seconds: float, provider: str, model: str, operation_name: str = "chat"
) -> None:
"""Record time to generate histogram.
Args:
ttg_seconds: Time to generate in seconds
provider: Model provider name
model: Model name
operation_name: Operation name (default: "chat")
"""
try:
if not hasattr(self, "hist_time_to_generate") or self.hist_time_to_generate is None:
return
attributes = {
"gen_ai.operation.name": operation_name,
"gen_ai.system": provider,
"gen_ai.request.model": model,
"gen_ai.response.model": model,
"stream": "true",
}
logger.info(
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
GEN_AI_STREAMING_TIME_TO_GENERATE,
ttg_seconds,
json.dumps(attributes, ensure_ascii=False),
)
self.hist_time_to_generate.record(ttg_seconds, attributes) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record time to generate", exc_info=True)
def record_trace_duration(self, duration_seconds: float, attributes: dict[str, str] | None = None) -> None:
"""Record end-to-end trace duration histogram in seconds.
Args:
duration_seconds: Trace duration in seconds
attributes: Optional attributes (e.g., conversation_mode, app_id)
"""
try:
if not hasattr(self, "hist_trace_duration") or self.hist_trace_duration is None:
return
attrs: dict[str, str] = {}
if attributes:
for k, v in attributes.items():
attrs[k] = str(v) if not isinstance(v, (str, int, float, bool)) else v # type: ignore[assignment]
logger.info(
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
GEN_AI_TRACE_DURATION,
duration_seconds,
json.dumps(attrs, ensure_ascii=False),
)
self.hist_trace_duration.record(duration_seconds, attrs) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record trace duration", exc_info=True)
def _create_and_export_span(self, span_data: SpanData) -> None:
"""Create span using OpenTelemetry Tracer API"""
try:
parent_context = None
if span_data.parent_span_id and span_data.parent_span_id in self.span_contexts:
parent_context = trace_api.set_span_in_context(
trace_api.NonRecordingSpan(self.span_contexts[span_data.parent_span_id])
)
span = self.tracer.start_span(
name=span_data.name,
context=parent_context,
kind=SpanKind.INTERNAL,
start_time=span_data.start_time,
)
self.span_contexts[span_data.span_id] = span.get_span_context()
if span_data.attributes:
attributes: dict[str, AttributeValue] = {}
for key, value in span_data.attributes.items():
if isinstance(value, (int, float, bool)):
attributes[key] = value
else:
attributes[key] = str(value)
span.set_attributes(attributes)
if span_data.events:
for event in span_data.events:
span.add_event(event.name, event.attributes, event.timestamp)
if span_data.status:
span.set_status(span_data.status)
# Manually end span; do not use context manager to avoid double-end warnings
span.end(end_time=span_data.end_time)
except Exception:
logger.exception("[Tencent APM] Error creating span: %s", span_data.name)
def api_check(self) -> bool:
"""Check API connectivity using socket connection test for gRPC endpoints"""
try:
# Resolve gRPC target consistently with exporters
_, _, host, port = self._resolve_grpc_target(self.endpoint)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(5)
result = sock.connect_ex((host, port))
sock.close()
if result == 0:
logger.info("[Tencent APM] Endpoint %s:%s is accessible", host, port)
return True
else:
logger.warning("[Tencent APM] Endpoint %s:%s is not accessible", host, port)
if host in ["127.0.0.1", "localhost"]:
logger.info("[Tencent APM] Development environment detected, allowing config save")
return True
return False
except Exception:
logger.exception("[Tencent APM] API check failed")
if "127.0.0.1" in self.endpoint or "localhost" in self.endpoint:
return True
return False
def get_project_url(self) -> str:
"""Get project console URL"""
return "https://console.cloud.tencent.com/apm"
def shutdown(self) -> None:
"""Shutdown the client and export remaining spans"""
try:
if self.span_processor:
logger.info("[Tencent APM] Flushing remaining spans before shutdown")
_ = self.span_processor.force_flush()
self.span_processor.shutdown()
if self.tracer_provider:
self.tracer_provider.shutdown()
# Shutdown instance-level meter provider
if self.meter_provider is not None:
try:
self.meter_provider.shutdown() # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Error shutting down meter provider", exc_info=True)
if self.metric_reader is not None:
try:
self.metric_reader.shutdown() # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Error shutting down metric reader", exc_info=True)
except Exception:
logger.exception("[Tencent APM] Error during client shutdown")
@staticmethod
def _resolve_grpc_target(endpoint: str, default_port: int = 4317) -> tuple[str, bool, str, int]:
"""Normalize endpoint to gRPC target and security flag.
Returns:
(grpc_endpoint, insecure, host, port)
"""
try:
if endpoint.startswith(("http://", "https://")):
parsed = urlparse(endpoint)
host = parsed.hostname or "localhost"
port = parsed.port or default_port
insecure = parsed.scheme == "http"
return f"{host}:{port}", insecure, host, port
host = endpoint
port = default_port
if ":" in endpoint:
parts = endpoint.rsplit(":", 1)
host = parts[0] or "localhost"
try:
port = int(parts[1])
except Exception:
port = default_port
insecure = ("localhost" in host) or ("127.0.0.1" in host)
return f"{host}:{port}", insecure, host, port
except Exception:
host, port = "localhost", default_port
return f"{host}:{port}", True, host, port

View File

@@ -0,0 +1 @@
# Tencent trace entities module

View File

@@ -0,0 +1,89 @@
from enum import Enum
# public
GEN_AI_SESSION_ID = "gen_ai.session.id"
GEN_AI_USER_ID = "gen_ai.user.id"
GEN_AI_USER_NAME = "gen_ai.user.name"
GEN_AI_SPAN_KIND = "gen_ai.span.kind"
GEN_AI_FRAMEWORK = "gen_ai.framework"
GEN_AI_IS_ENTRY = "gen_ai.is_entry"  # marks the entry span so LLM-related traces can be counted
# Chain
INPUT_VALUE = "gen_ai.entity.input"
OUTPUT_VALUE = "gen_ai.entity.output"
# Retriever
RETRIEVAL_QUERY = "retrieval.query"
RETRIEVAL_DOCUMENT = "retrieval.document"
# GENERATION
GEN_AI_MODEL_NAME = "gen_ai.response.model"
GEN_AI_PROVIDER = "gen_ai.provider.name"
GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
GEN_AI_PROMPT_TEMPLATE_TEMPLATE = "gen_ai.prompt_template.template"
GEN_AI_PROMPT_TEMPLATE_VARIABLE = "gen_ai.prompt_template.variable"
GEN_AI_PROMPT = "gen_ai.prompt"
GEN_AI_COMPLETION = "gen_ai.completion"
GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
# Streaming Span Attributes
GEN_AI_IS_STREAMING_REQUEST = "llm.is_streaming" # Same as OpenLLMetry semconv
# Tool
TOOL_NAME = "tool.name"
TOOL_DESCRIPTION = "tool.description"
TOOL_PARAMETERS = "tool.parameters"
# Instrumentation Library
INSTRUMENTATION_NAME = "dify-sdk"
INSTRUMENTATION_VERSION = "0.1.0"
INSTRUMENTATION_LANGUAGE = "python"
# Metrics
LLM_OPERATION_DURATION = "gen_ai.client.operation.duration"
GEN_AI_TOKEN_USAGE = "gen_ai.client.token.usage"
GEN_AI_SERVER_TIME_TO_FIRST_TOKEN = "gen_ai.server.time_to_first_token"
GEN_AI_STREAMING_TIME_TO_GENERATE = "gen_ai.streaming.time_to_generate"
# End-to-end LLM trace duration; a metric specific to Tencent APM
GEN_AI_TRACE_DURATION = "gen_ai.trace.duration"
# Token Usage Attributes
GEN_AI_OPERATION_NAME = "gen_ai.operation.name"
GEN_AI_REQUEST_MODEL = "gen_ai.request.model"
GEN_AI_RESPONSE_MODEL = "gen_ai.response.model"
GEN_AI_SYSTEM = "gen_ai.system"
GEN_AI_TOKEN_TYPE = "gen_ai.token.type"
SERVER_ADDRESS = "server.address"
class GenAISpanKind(Enum):
WORKFLOW = "WORKFLOW" # OpenLLMetry
RETRIEVER = "RETRIEVER" # RAG
GENERATION = "GENERATION" # Langfuse
TOOL = "TOOL" # OpenLLMetry
AGENT = "AGENT" # OpenLLMetry
TASK = "TASK" # OpenLLMetry

View File

@@ -0,0 +1,21 @@
from collections.abc import Sequence
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
from pydantic import BaseModel, Field
class SpanData(BaseModel):
model_config = {"arbitrary_types_allowed": True}
trace_id: int = Field(..., description="The unique identifier for the trace.")
parent_span_id: int | None = Field(None, description="The ID of the parent span, if any.")
span_id: int = Field(..., description="The unique identifier for this span.")
name: str = Field(..., description="The name of the span.")
attributes: dict[str, str] = Field(default_factory=dict, description="Attributes associated with the span.")
events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.")
links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.")
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
start_time: int = Field(..., description="The start time of the span in nanoseconds.")
end_time: int = Field(..., description="The end time of the span in nanoseconds.")
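# Minimal construction sketch (values are illustrative; real IDs come from
# TencentTraceUtils and times from convert_datetime_to_nanoseconds):
#
#     span = SpanData(
#         trace_id=0x1234,
#         span_id=0x5678,
#         name="message",
#         start_time=1_700_000_000_000_000_000,
#         end_time=1_700_000_001_000_000_000,
#     )
#
# parent_span_id, attributes, events, links, and status fall back to the
# declared defaults.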

View File

@@ -0,0 +1,383 @@
"""
Tencent APM Span Builder - handles all span construction logic
"""
import json
import logging
from datetime import datetime
from opentelemetry.trace import Status, StatusCode
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
MessageTraceInfo,
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.ops.tencent_trace.entities.semconv import (
GEN_AI_COMPLETION,
GEN_AI_FRAMEWORK,
GEN_AI_IS_ENTRY,
GEN_AI_IS_STREAMING_REQUEST,
GEN_AI_MODEL_NAME,
GEN_AI_PROMPT,
GEN_AI_PROVIDER,
GEN_AI_RESPONSE_FINISH_REASON,
GEN_AI_SESSION_ID,
GEN_AI_SPAN_KIND,
GEN_AI_USAGE_INPUT_TOKENS,
GEN_AI_USAGE_OUTPUT_TOKENS,
GEN_AI_USAGE_TOTAL_TOKENS,
GEN_AI_USER_ID,
INPUT_VALUE,
OUTPUT_VALUE,
RETRIEVAL_DOCUMENT,
RETRIEVAL_QUERY,
TOOL_DESCRIPTION,
TOOL_NAME,
TOOL_PARAMETERS,
GenAISpanKind,
)
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
from core.ops.tencent_trace.utils import TencentTraceUtils
from core.rag.models.document import Document
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
logger = logging.getLogger(__name__)
class TencentSpanBuilder:
"""Builder class for constructing different types of spans"""
@staticmethod
def _get_time_nanoseconds(time_value: datetime | None) -> int:
"""Convert datetime to nanoseconds for span creation."""
return TencentTraceUtils.convert_datetime_to_nanoseconds(time_value)
@staticmethod
def build_workflow_spans(
trace_info: WorkflowTraceInfo, trace_id: int, user_id: str, links: list | None = None
) -> list[SpanData]:
"""Build workflow-related spans"""
spans = []
links = links or []
message_span_id = None
workflow_span_id = TencentTraceUtils.convert_to_span_id(trace_info.workflow_run_id, "workflow")
if hasattr(trace_info, "metadata") and trace_info.metadata.get("conversation_id"):
message_span_id = TencentTraceUtils.convert_to_span_id(trace_info.workflow_run_id, "message")
status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
if message_span_id:
message_span = TencentSpanBuilder._build_message_span(
trace_info, trace_id, message_span_id, user_id, status, links
)
spans.append(message_span)
workflow_span = TencentSpanBuilder._build_workflow_span(
trace_info, trace_id, workflow_span_id, message_span_id, user_id, status, links
)
spans.append(workflow_span)
return spans
@staticmethod
def _build_message_span(
trace_info: WorkflowTraceInfo, trace_id: int, message_span_id: int, user_id: str, status: Status, links: list
) -> SpanData:
"""Build message span for chatflow"""
return SpanData(
trace_id=trace_id,
parent_span_id=None,
span_id=message_span_id,
name="message",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_IS_ENTRY: "true",
INPUT_VALUE: trace_info.workflow_run_inputs.get("sys.query", ""),
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
},
status=status,
links=links,
)
@staticmethod
def _build_workflow_span(
trace_info: WorkflowTraceInfo,
trace_id: int,
workflow_span_id: int,
message_span_id: int | None,
user_id: str,
status: Status,
links: list,
) -> SpanData:
"""Build workflow span"""
attributes = {
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
}
if message_span_id is None:
attributes[GEN_AI_IS_ENTRY] = "true"
return SpanData(
trace_id=trace_id,
parent_span_id=message_span_id,
span_id=workflow_span_id,
name="workflow",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes=attributes,
status=status,
links=links,
)
@staticmethod
def build_workflow_llm_span(
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
"""Build LLM span for workflow nodes."""
process_data = node_execution.process_data or {}
outputs = node_execution.outputs or {}
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
attributes = {
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
GEN_AI_PROVIDER: process_data.get("model_provider", ""),
GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
GEN_AI_COMPLETION: str(outputs.get("text", "")),
GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""),
INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
OUTPUT_VALUE: str(outputs.get("text", "")),
}
if usage_data.get("time_to_first_token") is not None:
attributes[GEN_AI_IS_STREAMING_REQUEST] = "true"
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
name="GENERATION",
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
attributes=attributes,
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
)
@staticmethod
def build_message_span(
trace_info: MessageTraceInfo, trace_id: int, user_id: str, links: list | None = None
) -> SpanData:
"""Build message span."""
links = links or []
status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
attributes = {
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_IS_ENTRY: "true",
INPUT_VALUE: str(trace_info.inputs or ""),
OUTPUT_VALUE: str(trace_info.outputs or ""),
}
if trace_info.is_streaming_request:
attributes[GEN_AI_IS_STREAMING_REQUEST] = "true"
return SpanData(
trace_id=trace_id,
parent_span_id=None,
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message"),
name="message",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes=attributes,
status=status,
links=links,
)
@staticmethod
def build_tool_span(trace_info: ToolTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
"""Build tool span."""
status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
return SpanData(
trace_id=trace_id,
parent_span_id=parent_span_id,
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "tool"),
name=trace_info.tool_name,
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
GEN_AI_FRAMEWORK: "dify",
TOOL_NAME: trace_info.tool_name,
TOOL_DESCRIPTION: "",
TOOL_PARAMETERS: json.dumps(trace_info.tool_parameters, ensure_ascii=False),
INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
OUTPUT_VALUE: str(trace_info.tool_outputs),
},
status=status,
)
@staticmethod
def build_retrieval_span(trace_info: DatasetRetrievalTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
"""Build dataset retrieval span."""
status = Status(StatusCode.OK)
if getattr(trace_info, "error", None):
status = Status(StatusCode.ERROR, trace_info.error) # type: ignore[arg-type]
documents_data = TencentSpanBuilder._extract_retrieval_documents(trace_info.documents)
return SpanData(
trace_id=trace_id,
parent_span_id=parent_span_id,
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "retrieval"),
name="retrieval",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
GEN_AI_FRAMEWORK: "dify",
RETRIEVAL_QUERY: str(trace_info.inputs or ""),
RETRIEVAL_DOCUMENT: json.dumps(documents_data, ensure_ascii=False),
INPUT_VALUE: str(trace_info.inputs or ""),
OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False),
},
status=status,
)
@staticmethod
def _get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status:
"""Get workflow node execution status."""
if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED:
return Status(StatusCode.OK)
elif node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]:
return Status(StatusCode.ERROR, str(node_execution.error))
return Status(StatusCode.UNSET)
@staticmethod
def build_workflow_retrieval_span(
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
"""Build knowledge retrieval span for workflow nodes."""
input_value = ""
if node_execution.inputs:
input_value = str(node_execution.inputs.get("query", ""))
output_value = ""
if node_execution.outputs:
output_value = json.dumps(node_execution.outputs.get("result", []), ensure_ascii=False)
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
GEN_AI_FRAMEWORK: "dify",
RETRIEVAL_QUERY: input_value,
RETRIEVAL_DOCUMENT: output_value,
INPUT_VALUE: input_value,
OUTPUT_VALUE: output_value,
},
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
)
@staticmethod
def build_workflow_tool_span(
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
"""Build tool span for workflow nodes."""
tool_info = {}
if node_execution.metadata:
tool_info = node_execution.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
GEN_AI_FRAMEWORK: "dify",
TOOL_NAME: node_execution.title,
TOOL_DESCRIPTION: json.dumps(tool_info, ensure_ascii=False),
TOOL_PARAMETERS: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
INPUT_VALUE: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
},
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
)
@staticmethod
def build_workflow_task_span(
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
"""Build generic task span for workflow nodes."""
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
},
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
)
@staticmethod
def _extract_retrieval_documents(documents: list[Document]):
"""Extract documents data for retrieval tracing."""
documents_data = []
for document in documents:
document_data = {
"content": document.page_content,
"metadata": {
"dataset_id": document.metadata.get("dataset_id"),
"doc_id": document.metadata.get("doc_id"),
"document_id": document.metadata.get("document_id"),
},
"score": document.metadata.get("score"),
}
documents_data.append(document_data)
return documents_data
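# Example output shape (assuming a Document whose metadata carries the usual
# Dify retrieval keys):
#     [{"content": "chunk text",
#       "metadata": {"dataset_id": "d1", "doc_id": "c1", "document_id": "doc1"},
#       "score": 0.87}]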

View File

@@ -0,0 +1,520 @@
"""
Tencent APM tracing implementation with separated concerns
"""
import logging
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import TencentConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.ops.tencent_trace.client import TencentTraceClient
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
from core.ops.tencent_trace.span_builder import TencentSpanBuilder
from core.ops.tencent_trace.utils import TencentTraceUtils
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
)
from core.workflow.nodes import NodeType
from extensions.ext_database import db
from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
class TencentDataTrace(BaseTraceInstance):
"""
Tencent APM trace implementation with single responsibility principle.
Acts as a coordinator that delegates specific tasks to specialized classes.
"""
def __init__(self, tencent_config: TencentConfig):
super().__init__(tencent_config)
self.trace_client = TencentTraceClient(
service_name=tencent_config.service_name,
endpoint=tencent_config.endpoint,
token=tencent_config.token,
metrics_export_interval_sec=5,
)
def trace(self, trace_info: BaseTraceInfo) -> None:
"""Main tracing entry point - coordinates different trace types."""
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
elif isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
elif isinstance(trace_info, ModerationTraceInfo):
pass
elif isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
elif isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
elif isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
elif isinstance(trace_info, GenerateNameTraceInfo):
pass
def api_check(self) -> bool:
return self.trace_client.api_check()
def get_project_url(self) -> str:
return self.trace_client.get_project_url()
def workflow_trace(self, trace_info: WorkflowTraceInfo) -> None:
"""Handle workflow tracing by coordinating data retrieval and span construction."""
try:
trace_id = TencentTraceUtils.convert_to_trace_id(trace_info.workflow_run_id)
links = []
if trace_info.trace_id:
links.append(TencentTraceUtils.create_link(trace_info.trace_id))
user_id = self._get_user_id(trace_info)
workflow_spans = TencentSpanBuilder.build_workflow_spans(trace_info, trace_id, str(user_id), links)
for span in workflow_spans:
self.trace_client.add_span(span)
self._process_workflow_nodes(trace_info, trace_id)
# Record trace duration for entry span
self._record_workflow_trace_duration(trace_info)
except Exception:
logger.exception("[Tencent APM] Failed to process workflow trace")
def message_trace(self, trace_info: MessageTraceInfo) -> None:
"""Handle message tracing."""
try:
trace_id = TencentTraceUtils.convert_to_trace_id(trace_info.message_id)
user_id = self._get_user_id(trace_info)
links = []
if trace_info.trace_id:
links.append(TencentTraceUtils.create_link(trace_info.trace_id))
message_span = TencentSpanBuilder.build_message_span(trace_info, trace_id, str(user_id), links)
self.trace_client.add_span(message_span)
self._record_message_llm_metrics(trace_info)
# Record trace duration for entry span
self._record_message_trace_duration(trace_info)
except Exception:
logger.exception("[Tencent APM] Failed to process message trace")
def tool_trace(self, trace_info: ToolTraceInfo) -> None:
"""Handle tool tracing."""
try:
parent_span_id = None
trace_root_id = None
if trace_info.message_id:
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
trace_root_id = trace_info.message_id
if parent_span_id and trace_root_id:
trace_id = TencentTraceUtils.convert_to_trace_id(trace_root_id)
tool_span = TencentSpanBuilder.build_tool_span(trace_info, trace_id, parent_span_id)
self.trace_client.add_span(tool_span)
except Exception:
logger.exception("[Tencent APM] Failed to process tool trace")
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo) -> None:
"""Handle dataset retrieval tracing."""
try:
parent_span_id = None
trace_root_id = None
if trace_info.message_id:
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
trace_root_id = trace_info.message_id
if parent_span_id and trace_root_id:
trace_id = TencentTraceUtils.convert_to_trace_id(trace_root_id)
retrieval_span = TencentSpanBuilder.build_retrieval_span(trace_info, trace_id, parent_span_id)
self.trace_client.add_span(retrieval_span)
except Exception:
logger.exception("[Tencent APM] Failed to process dataset retrieval trace")
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo) -> None:
"""Handle suggested question tracing"""
try:
logger.info("[Tencent APM] Processing suggested question trace")
except Exception:
logger.exception("[Tencent APM] Failed to process suggested question trace")
def _process_workflow_nodes(self, trace_info: WorkflowTraceInfo, trace_id: int) -> None:
"""Process workflow node executions."""
try:
workflow_span_id = TencentTraceUtils.convert_to_span_id(trace_info.workflow_run_id, "workflow")
node_executions = self._get_workflow_node_executions(trace_info)
for node_execution in node_executions:
try:
node_span = self._build_workflow_node_span(node_execution, trace_id, trace_info, workflow_span_id)
if node_span:
self.trace_client.add_span(node_span)
if node_execution.node_type == NodeType.LLM:
self._record_llm_metrics(node_execution)
except Exception:
logger.exception("[Tencent APM] Failed to process node execution: %s", node_execution.id)
except Exception:
logger.exception("[Tencent APM] Failed to process workflow nodes")
def _build_workflow_node_span(
self, node_execution: WorkflowNodeExecution, trace_id: int, trace_info: WorkflowTraceInfo, workflow_span_id: int
) -> SpanData | None:
"""Build span for different node types"""
try:
if node_execution.node_type == NodeType.LLM:
return TencentSpanBuilder.build_workflow_llm_span(
trace_id, workflow_span_id, trace_info, node_execution
)
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
return TencentSpanBuilder.build_workflow_retrieval_span(
trace_id, workflow_span_id, trace_info, node_execution
)
elif node_execution.node_type == NodeType.TOOL:
return TencentSpanBuilder.build_workflow_tool_span(
trace_id, workflow_span_id, trace_info, node_execution
)
else:
# Handle all other node types as generic tasks
return TencentSpanBuilder.build_workflow_task_span(
trace_id, workflow_span_id, trace_info, node_execution
)
except Exception:
logger.debug(
"[Tencent APM] Error building span for node %s: %s",
node_execution.id,
node_execution.node_type,
exc_info=True,
)
return None
def _get_workflow_node_executions(self, trace_info: WorkflowTraceInfo) -> list[WorkflowNodeExecution]:
"""Retrieve workflow node executions from database."""
try:
session_maker = sessionmaker(bind=db.engine)
with Session(db.engine, expire_on_commit=False) as session:
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
app_stmt = select(App).where(App.id == app_id)
app = session.scalar(app_stmt)
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator")
account_stmt = select(Account).where(Account.id == app.created_by)
service_account = session.scalar(account_stmt)
if not service_account:
raise ValueError(f"Creator account not found for app {app_id}")
current_tenant = (
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
)
if not current_tenant:
raise ValueError(f"Current tenant not found for account {service_account.id}")
service_account.set_tenant_id(current_tenant.tenant_id)
repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_maker,
user=service_account,
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
executions = repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id)
return list(executions)
except Exception:
logger.exception("[Tencent APM] Failed to get workflow node executions")
return []
def _get_user_id(self, trace_info: BaseTraceInfo) -> str:
"""Get user ID from trace info."""
try:
tenant_id = None
user_id = None
if isinstance(trace_info, (WorkflowTraceInfo, GenerateNameTraceInfo)):
tenant_id = trace_info.tenant_id
if hasattr(trace_info, "metadata") and trace_info.metadata:
user_id = trace_info.metadata.get("user_id")
if user_id and tenant_id:
stmt = (
select(Account.name)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.where(Account.id == user_id, TenantAccountJoin.tenant_id == tenant_id)
)
session_maker = sessionmaker(bind=db.engine)
with session_maker() as session:
account_name = session.scalar(stmt)
return account_name or str(user_id)
elif user_id:
return str(user_id)
return "anonymous"
except Exception:
logger.exception("[Tencent APM] Failed to get user ID")
return "unknown"
def _record_llm_metrics(self, node_execution: WorkflowNodeExecution) -> None:
"""Record LLM performance metrics"""
try:
process_data = node_execution.process_data or {}
outputs = node_execution.outputs or {}
usage = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
model_provider = process_data.get("model_provider", "unknown")
model_name = process_data.get("model_name", "unknown")
model_mode = process_data.get("model_mode", "chat")
# Record LLM duration
if hasattr(self.trace_client, "record_llm_duration"):
latency_s = float(usage.get("latency", 0.0))
if latency_s > 0:
# Determine if streaming from usage metrics
is_streaming = usage.get("time_to_first_token") is not None
attributes = {
"gen_ai.system": model_provider,
"gen_ai.response.model": model_name,
"gen_ai.operation.name": model_mode,
"stream": "true" if is_streaming else "false",
}
self.trace_client.record_llm_duration(latency_s, attributes)
# Record streaming metrics from usage
time_to_first_token = usage.get("time_to_first_token")
if time_to_first_token is not None and hasattr(self.trace_client, "record_time_to_first_token"):
ttft_seconds = float(time_to_first_token)
if ttft_seconds > 0:
self.trace_client.record_time_to_first_token(
ttft_seconds=ttft_seconds, provider=model_provider, model=model_name, operation_name=model_mode
)
time_to_generate = usage.get("time_to_generate")
if time_to_generate is not None and hasattr(self.trace_client, "record_time_to_generate"):
ttg_seconds = float(time_to_generate)
if ttg_seconds > 0:
self.trace_client.record_time_to_generate(
ttg_seconds=ttg_seconds, provider=model_provider, model=model_name, operation_name=model_mode
)
# Record token usage
if hasattr(self.trace_client, "record_token_usage"):
# Extract token counts
input_tokens = int(usage.get("prompt_tokens", 0))
output_tokens = int(usage.get("completion_tokens", 0))
if input_tokens > 0 or output_tokens > 0:
server_address = model_provider  # the provider name stands in for the server address label
# Record input tokens
if input_tokens > 0:
self.trace_client.record_token_usage(
token_count=input_tokens,
token_type="input",
operation_name=model_mode,
request_model=model_name,
response_model=model_name,
server_address=server_address,
provider=model_provider,
)
# Record output tokens
if output_tokens > 0:
self.trace_client.record_token_usage(
token_count=output_tokens,
token_type="output",
operation_name=model_mode,
request_model=model_name,
response_model=model_name,
server_address=server_address,
provider=model_provider,
)
except Exception:
logger.debug("[Tencent APM] Failed to record LLM metrics")
def _record_message_llm_metrics(self, trace_info: MessageTraceInfo) -> None:
"""Record LLM metrics for message traces"""
try:
trace_metadata = trace_info.metadata or {}
message_data = trace_info.message_data or {}
provider_latency = 0.0
if isinstance(message_data, dict):
provider_latency = float(message_data.get("provider_response_latency", 0.0) or 0.0)
else:
provider_latency = float(getattr(message_data, "provider_response_latency", 0.0) or 0.0)
model_provider = trace_metadata.get("ls_provider") or (
message_data.get("model_provider", "") if isinstance(message_data, dict) else ""
)
model_name = trace_metadata.get("ls_model_name") or (
message_data.get("model_id", "") if isinstance(message_data, dict) else ""
)
# Record LLM duration
if provider_latency > 0 and hasattr(self.trace_client, "record_llm_duration"):
is_streaming = trace_info.is_streaming_request
duration_attributes = {
"gen_ai.system": model_provider,
"gen_ai.response.model": model_name,
"gen_ai.operation.name": "chat", # Message traces are always chat
"stream": "true" if is_streaming else "false",
}
self.trace_client.record_llm_duration(provider_latency, duration_attributes)
# Record streaming metrics for message traces
if trace_info.is_streaming_request:
# Record time to first token
if trace_info.gen_ai_server_time_to_first_token is not None and hasattr(
self.trace_client, "record_time_to_first_token"
):
ttft_seconds = float(trace_info.gen_ai_server_time_to_first_token)
if ttft_seconds > 0:
self.trace_client.record_time_to_first_token(
ttft_seconds=ttft_seconds, provider=str(model_provider or ""), model=str(model_name or "")
)
# Record time to generate
if trace_info.llm_streaming_time_to_generate is not None and hasattr(
self.trace_client, "record_time_to_generate"
):
ttg_seconds = float(trace_info.llm_streaming_time_to_generate)
if ttg_seconds > 0:
self.trace_client.record_time_to_generate(
ttg_seconds=ttg_seconds, provider=str(model_provider or ""), model=str(model_name or "")
)
# Record token usage
if hasattr(self.trace_client, "record_token_usage"):
input_tokens = int(trace_info.message_tokens or 0)
output_tokens = int(trace_info.answer_tokens or 0)
if input_tokens > 0:
self.trace_client.record_token_usage(
token_count=input_tokens,
token_type="input",
operation_name="chat",
request_model=str(model_name or ""),
response_model=str(model_name or ""),
server_address=str(model_provider or ""),
provider=str(model_provider or ""),
)
if output_tokens > 0:
self.trace_client.record_token_usage(
token_count=output_tokens,
token_type="output",
operation_name="chat",
request_model=str(model_name or ""),
response_model=str(model_name or ""),
server_address=str(model_provider or ""),
provider=str(model_provider or ""),
)
except Exception:
logger.debug("[Tencent APM] Failed to record message LLM metrics")
def _record_workflow_trace_duration(self, trace_info: WorkflowTraceInfo) -> None:
"""Record end-to-end workflow trace duration."""
try:
if not hasattr(self.trace_client, "record_trace_duration"):
return
# Calculate duration from start_time and end_time to match span duration
if trace_info.start_time and trace_info.end_time:
duration_s = (trace_info.end_time - trace_info.start_time).total_seconds()
else:
# Fallback to workflow_run_elapsed_time if timestamps not available
duration_s = float(trace_info.workflow_run_elapsed_time)
if duration_s > 0:
attributes = {
"conversation_mode": "workflow",
"workflow_status": trace_info.workflow_run_status,
}
# Add conversation_id if available
if trace_info.conversation_id:
attributes["has_conversation"] = "true"
else:
attributes["has_conversation"] = "false"
self.trace_client.record_trace_duration(duration_s, attributes)
except Exception:
logger.debug("[Tencent APM] Failed to record workflow trace duration")
def _record_message_trace_duration(self, trace_info: MessageTraceInfo) -> None:
"""Record end-to-end message trace duration."""
try:
if not hasattr(self.trace_client, "record_trace_duration"):
return
# Calculate duration from start_time and end_time
if trace_info.start_time and trace_info.end_time:
duration = (trace_info.end_time - trace_info.start_time).total_seconds()
if duration > 0:
attributes = {
"conversation_mode": trace_info.conversation_mode,
}
# Add streaming flag if available
if hasattr(trace_info, "is_streaming_request"):
attributes["stream"] = "true" if trace_info.is_streaming_request else "false"
self.trace_client.record_trace_duration(duration, attributes)
except Exception:
logger.debug("[Tencent APM] Failed to record message trace duration")
def __del__(self):
"""Ensure proper cleanup on garbage collection."""
try:
if hasattr(self, "trace_client"):
self.trace_client.shutdown()
except Exception:
pass
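# Hedged wiring sketch: instances are normally created by OpsTraceManager from
# a stored TencentConfig, but standalone use would look roughly like this
# (endpoint and token values are placeholders):
#
#     config = TencentConfig(service_name="dify-app", endpoint="https://apm.example.com:4317", token="<token>")
#     tracer = TencentDataTrace(config)
#     tracer.trace(trace_info)  # any BaseTraceInfo subclass handled by trace()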

View File

@@ -0,0 +1,65 @@
"""
Utility functions for Tencent APM tracing
"""
import hashlib
import random
import uuid
from datetime import datetime
from typing import cast
from opentelemetry.trace import Link, SpanContext, TraceFlags
class TencentTraceUtils:
"""Utility class for common tracing operations."""
INVALID_SPAN_ID = 0x0000000000000000
INVALID_TRACE_ID = 0x00000000000000000000000000000000
@staticmethod
def convert_to_trace_id(uuid_v4: str | None) -> int:
try:
uuid_obj = uuid.UUID(uuid_v4) if uuid_v4 else uuid.uuid4()
except Exception as e:
raise ValueError(f"Invalid UUID input: {e}")
return cast(int, uuid_obj.int)
@staticmethod
def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int:
try:
uuid_obj = uuid.UUID(uuid_v4) if uuid_v4 else uuid.uuid4()
except Exception as e:
raise ValueError(f"Invalid UUID input: {e}")
combined_key = f"{uuid_obj.hex}-{span_type}"
hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest()
return int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
@staticmethod
def generate_span_id() -> int:
span_id = random.getrandbits(64)
while span_id == TencentTraceUtils.INVALID_SPAN_ID:
span_id = random.getrandbits(64)
return span_id
@staticmethod
def convert_datetime_to_nanoseconds(start_time: datetime | None) -> int:
if start_time is None:
start_time = datetime.now()
timestamp_in_seconds = start_time.timestamp()
return int(timestamp_in_seconds * 1e9)
@staticmethod
def create_link(trace_id_str: str) -> Link:
try:
trace_id = int(trace_id_str, 16) if len(trace_id_str) == 32 else cast(int, uuid.UUID(trace_id_str).int)
except (ValueError, TypeError):
trace_id = cast(int, uuid.uuid4().int)
span_context = SpanContext(
trace_id=trace_id,
span_id=TencentTraceUtils.INVALID_SPAN_ID,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
return Link(span_context)
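
# Illustrative usage (the run UUID below is hypothetical): one workflow run
# UUID yields a stable 128-bit trace ID, and each span type hashes to its own
# deterministic 64-bit span ID, so re-exports reproduce identical IDs.
if __name__ == "__main__":
    run_id = "3f2b8c9e-1d4a-4f6b-9c2d-7e8a5b1c0d3e"
    trace_id = TencentTraceUtils.convert_to_trace_id(run_id)
    workflow_span = TencentTraceUtils.convert_to_span_id(run_id, "workflow")
    node_span = TencentTraceUtils.convert_to_span_id(run_id, "node")
    assert workflow_span != node_span  # the span type participates in the hash
    link = TencentTraceUtils.create_link(f"{trace_id:032x}")
    print(f"trace={trace_id:032x} workflow_span={workflow_span:016x} link_trace={link.context.trace_id:032x}")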

160
dify/api/core/ops/utils.py Normal file
View File

@@ -0,0 +1,160 @@
from contextlib import contextmanager
from datetime import datetime
from urllib.parse import urlparse
from sqlalchemy import select
from extensions.ext_database import db
from models.model import Message
def filter_none_values(data: dict):
new_data = {}
for key, value in data.items():
if value is None:
continue
if isinstance(value, datetime):
new_data[key] = value.isoformat()
else:
new_data[key] = value
return new_data
def get_message_data(message_id: str):
return db.session.scalar(select(Message).where(Message.id == message_id))
@contextmanager
def measure_time():
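    """Yield a dict with wall-clock "start" and "end" datetimes around a block.

    Illustrative usage (sketch)::

        with measure_time() as timing_info:
            ...  # the traced work
        elapsed = (timing_info["end"] - timing_info["start"]).total_seconds()
    """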
timing_info = {"start": datetime.now(), "end": None}
try:
yield timing_info
finally:
timing_info["end"] = datetime.now()
def replace_text_with_content(data):
if isinstance(data, dict):
new_data = {}
for key, value in data.items():
if key == "text":
new_data["content"] = value
else:
new_data[key] = replace_text_with_content(value)
return new_data
elif isinstance(data, list):
return [replace_text_with_content(item) for item in data]
else:
return data
def generate_dotted_order(run_id: str, start_time: str | datetime, parent_dotted_order: str | None = None) -> str:
"""
    Generate a dotted_order string for LangSmith.
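
    Example (hypothetical run IDs; LangSmith sorts these lexicographically)::

        >>> parent = generate_dotted_order("run-a", "2025-01-01T00:00:00")
        >>> parent
        '20250101T000000000Zrun-a'
        >>> generate_dotted_order("run-b", "2025-01-01T00:00:01", parent)
        '20250101T000000000Zrun-a.20250101T000001000Zrun-b'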
"""
start_time = datetime.fromisoformat(start_time) if isinstance(start_time, str) else start_time
timestamp = start_time.strftime("%Y%m%dT%H%M%S%f")[:-3] + "Z"
current_segment = f"{timestamp}{run_id}"
if parent_dotted_order is None:
return current_segment
return f"{parent_dotted_order}.{current_segment}"
def validate_url(url: str, default_url: str, allowed_schemes: tuple = ("https", "http")) -> str:
"""
Validate and normalize URL with proper error handling.
    NOTE: This function does not retain the `path`, query, or fragment of the
    provided URL. It is deprecated and kept only for compatibility; new
    implementations should use `validate_url_with_path` instead.
Args:
url: The URL to validate
default_url: Default URL to use if input is None or empty
allowed_schemes: Tuple of allowed URL schemes (default: https, http)
Returns:
Normalized URL string
Raises:
ValueError: If URL format is invalid or scheme not allowed
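
    Example (hypothetical endpoint; note that path and query are dropped)::

        >>> validate_url("https://api.example.com/v1/traces?x=1", "https://api.example.com")
        'https://api.example.com'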
"""
if not url or url.strip() == "":
return default_url
# Parse URL to validate format
parsed = urlparse(url)
# Check if scheme is allowed
if parsed.scheme not in allowed_schemes:
raise ValueError(f"URL scheme must be one of: {', '.join(allowed_schemes)}")
# Reconstruct URL with only scheme, netloc (removing path, query, fragment)
normalized_url = f"{parsed.scheme}://{parsed.netloc}"
return normalized_url
def validate_url_with_path(url: str, default_url: str, required_suffix: str | None = None) -> str:
"""
Validate URL that may include path components
Args:
url: The URL to validate
default_url: Default URL to use if input is None or empty
required_suffix: Optional suffix that URL must end with
Returns:
Validated URL string
Raises:
ValueError: If URL format is invalid or doesn't match required suffix
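
    Example (hypothetical endpoint; the path is preserved)::

        >>> validate_url_with_path("https://apm.example.com/otlp/v1/traces", "https://apm.example.com", "/v1/traces")
        'https://apm.example.com/otlp/v1/traces'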
"""
if not url or url.strip() == "":
return default_url
# Parse URL to validate format
parsed = urlparse(url)
# Check if scheme is allowed
if parsed.scheme not in ("https", "http"):
raise ValueError("URL must start with https:// or http://")
# Check required suffix if specified
if required_suffix and not url.endswith(required_suffix):
raise ValueError(f"URL should end with {required_suffix}")
return url
def validate_project_name(project: str, default_name: str) -> str:
"""
Validate and normalize project name
Args:
project: Project name to validate
default_name: Default name to use if input is None or empty
Returns:
Normalized project name
"""
if not project or project.strip() == "":
return default_name
return project.strip()
def validate_integer_id(id_str: str) -> str:
"""
Validate and normalize integer ID
"""
id_str = id_str.strip()
if not id_str.isdigit():
raise ValueError("ID must be a valid integer")
return id_str

View File

@@ -0,0 +1,98 @@
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel, Field, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.ops.utils import replace_text_with_content
class WeaveTokenUsage(BaseModel):
input_tokens: int | None = None
output_tokens: int | None = None
total_tokens: int | None = None
class WeaveMultiModel(BaseModel):
file_list: list[str] | None = Field(None, description="List of files")
class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
id: str = Field(..., description="ID of the trace")
op: str = Field(..., description="Name of the operation")
    inputs: str | Mapping[str, Any] | list | None = Field(None, description="Inputs of the trace")
    outputs: str | Mapping[str, Any] | list | None = Field(None, description="Outputs of the trace")
    attributes: str | dict[str, Any] | list | None = Field(
        None, description="Metadata and attributes associated with trace"
    )
exception: str | None = Field(None, description="Exception message of the trace")
@field_validator("inputs", "outputs")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name
values = info.data
if v == {} or v is None:
return v
        usage_metadata = {
            "input_tokens": values.get("input_tokens") or 0,
            "output_tokens": values.get("output_tokens") or 0,
            "total_tokens": values.get("total_tokens") or 0,
        }
        file_list = values.get("file_list") or []
if isinstance(v, str):
if field_name == "inputs":
return {
"messages": {
"role": "user",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
elif field_name == "outputs":
return {
"choices": {
"role": "ai",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
elif isinstance(v, list):
data = {}
if len(v) > 0 and isinstance(v[0], dict):
# rename text to content
v = replace_text_with_content(data=v)
if field_name == "inputs":
data = {
"messages": [
dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) for msg in v
]
if isinstance(v, list)
else v,
}
elif field_name == "outputs":
data = {
"choices": {
"role": "ai",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
return data
        elif isinstance(v, dict):
            v["usage_metadata"] = usage_metadata
            v["file_list"] = file_list
            return v
        else:
            return {
                "choices": {
                    "role": "ai" if field_name == "outputs" else "user",
                    "content": str(v),
                    "usage_metadata": usage_metadata,
                    "file_list": file_list,
                },
            }
        return v
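
# Illustrative normalization (the call id and token counts are hypothetical):
# a bare string passed as `inputs` is wrapped into a chat-style "messages"
# payload with usage metadata attached, and string `outputs` become "choices".
if __name__ == "__main__":
    trace = WeaveTraceModel(
        id="call-1",
        op="llm",
        input_tokens=12,
        output_tokens=34,
        total_tokens=46,
        file_list=[],
        inputs="What is Dify?",
        outputs="An LLM app platform.",
    )
    assert trace.inputs["messages"]["role"] == "user"
    assert trace.outputs["choices"]["usage_metadata"]["total_tokens"] == 46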

View File

@@ -0,0 +1,523 @@
import logging
import os
import uuid
from datetime import UTC, datetime, timedelta
from typing import Any, cast
import wandb
import weave
from sqlalchemy.orm import sessionmaker
from weave.trace_server.trace_server_interface import (
CallEndReq,
CallStartReq,
EndedCallSchemaForInsert,
StartedCallSchemaForInsert,
SummaryInsertMap,
TraceStatus,
)
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import WeaveConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
class WeaveDataTrace(BaseTraceInstance):
def __init__(
self,
weave_config: WeaveConfig,
):
super().__init__(weave_config)
self.weave_api_key = weave_config.api_key
self.project_name = weave_config.project
self.entity = weave_config.entity
self.host = weave_config.host
# Login with API key first, including host if provided
if self.host:
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host)
else:
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
if not login_status:
logger.error("Failed to login to Weights & Biases with the provided API key")
raise ValueError("Weave login failed")
# Then initialize weave client
self.weave_client = weave.init(
project_name=(f"{self.entity}/{self.project_name}" if self.entity else self.project_name)
)
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
self.calls: dict[str, Any] = {}
self.project_id = f"{self.weave_client.entity}/{self.weave_client.project}"
    def get_project_url(self):
try:
project_identifier = f"{self.entity}/{self.project_name}" if self.entity else self.project_name
project_url = f"https://wandb.ai/{project_identifier}"
return project_url
except Exception as e:
logger.debug("Weave get run url failed: %s", str(e))
raise ValueError(f"Weave get run url failed: {str(e)}")
def trace(self, trace_info: BaseTraceInfo):
logger.debug("Trace info: %s", trace_info)
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo):
self.moderation_trace(trace_info)
if isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo):
self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
if trace_info.start_time is None:
trace_info.start_time = datetime.now()
if trace_info.message_id:
message_attributes = trace_info.metadata
message_attributes["workflow_app_log_id"] = trace_info.workflow_app_log_id
message_attributes["message_id"] = trace_info.message_id
message_attributes["workflow_run_id"] = trace_info.workflow_run_id
message_attributes["trace_id"] = trace_id
message_attributes["start_time"] = trace_info.start_time
message_attributes["end_time"] = trace_info.end_time
message_attributes["tags"] = ["message", "workflow"]
message_run = WeaveTraceModel(
id=trace_info.message_id,
op=str(TraceTaskName.MESSAGE_TRACE),
inputs=dict(trace_info.workflow_run_inputs),
outputs=dict(trace_info.workflow_run_outputs),
total_tokens=trace_info.total_tokens,
attributes=message_attributes,
exception=trace_info.error,
file_list=[],
)
self.start_call(message_run, parent_run_id=trace_info.workflow_run_id)
self.finish_call(message_run)
workflow_attributes = trace_info.metadata
workflow_attributes["workflow_run_id"] = trace_info.workflow_run_id
workflow_attributes["trace_id"] = trace_id
workflow_attributes["start_time"] = trace_info.start_time
workflow_attributes["end_time"] = trace_info.end_time
workflow_attributes["tags"] = ["dify_workflow"]
workflow_run = WeaveTraceModel(
file_list=trace_info.file_list,
total_tokens=trace_info.total_tokens,
id=trace_info.workflow_run_id,
op=str(TraceTaskName.WORKFLOW_TRACE),
inputs=dict(trace_info.workflow_run_inputs),
outputs=dict(trace_info.workflow_run_outputs),
attributes=workflow_attributes,
exception=trace_info.error,
)
self.start_call(workflow_run, parent_run_id=trace_info.message_id)
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
# Find the app's creator account
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
service_account = self.get_service_account_with_tenant(app_id)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Get all executions for this workflow run
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
workflow_run_id=trace_info.workflow_run_id
)
# rearrange workflow_node_executions by starting time
workflow_node_executions = sorted(workflow_node_executions, key=lambda x: x.created_at)
for node_execution in workflow_node_executions:
node_execution_id = node_execution.id
tenant_id = trace_info.tenant_id # Use from trace_info instead
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs or {}
outputs = node_execution.outputs or {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = node_execution.metadata or {}
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
attributes = {str(k): v for k, v in execution_metadata.items()}
attributes.update(
{
"workflow_run_id": trace_info.workflow_run_id,
"node_execution_id": node_execution_id,
"tenant_id": tenant_id,
"app_id": app_id,
"app_name": node_name,
"node_type": node_type,
"status": status,
}
)
process_data = node_execution.process_data or {}
if process_data and process_data.get("model_mode") == "chat":
attributes.update(
{
"ls_provider": process_data.get("model_provider", ""),
"ls_model_name": process_data.get("model_name", ""),
}
)
attributes["tags"] = ["node_execution"]
attributes["start_time"] = created_at
attributes["end_time"] = finished_at
attributes["elapsed_time"] = elapsed_time
attributes["workflow_run_id"] = trace_info.workflow_run_id
attributes["trace_id"] = trace_id
node_run = WeaveTraceModel(
total_tokens=node_total_tokens,
op=node_type,
inputs=inputs,
outputs=outputs,
file_list=trace_info.file_list,
attributes=attributes,
id=node_execution_id,
exception=None,
)
self.start_call(node_run, parent_run_id=trace_info.workflow_run_id)
self.finish_call(node_run)
self.finish_call(workflow_run)
def message_trace(self, trace_info: MessageTraceInfo):
# get message file data
        file_list = cast(list[str], trace_info.file_list) or []
        message_file_data: MessageFile | None = trace_info.message_file_data
        if message_file_data:
            file_list.append(f"{self.file_base_url}/{message_file_data.url}")
attributes = trace_info.metadata
message_data = trace_info.message_data
if message_data is None:
return
message_id = message_data.id
user_id = message_data.from_account_id
attributes["user_id"] = user_id
if message_data.from_end_user_id:
end_user_data: EndUser | None = (
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
end_user_id = end_user_data.session_id
attributes["end_user_id"] = end_user_id
attributes["message_id"] = message_id
attributes["start_time"] = trace_info.start_time
attributes["end_time"] = trace_info.end_time
attributes["tags"] = ["message", str(trace_info.conversation_mode)]
trace_id = trace_info.trace_id or message_id
attributes["trace_id"] = trace_id
message_run = WeaveTraceModel(
id=trace_id,
op=str(TraceTaskName.MESSAGE_TRACE),
input_tokens=trace_info.message_tokens,
output_tokens=trace_info.answer_tokens,
total_tokens=trace_info.total_tokens,
inputs=trace_info.inputs,
outputs=trace_info.outputs,
exception=trace_info.error,
file_list=file_list,
attributes=attributes,
)
self.start_call(message_run)
# create llm run parented to message run
llm_run = WeaveTraceModel(
id=str(uuid.uuid4()),
input_tokens=trace_info.message_tokens,
output_tokens=trace_info.answer_tokens,
total_tokens=trace_info.total_tokens,
op="llm",
inputs=trace_info.inputs,
outputs=trace_info.outputs,
attributes=attributes,
file_list=[],
exception=None,
)
self.start_call(
llm_run,
parent_run_id=trace_id,
)
self.finish_call(llm_run)
self.finish_call(message_run)
def moderation_trace(self, trace_info: ModerationTraceInfo):
if trace_info.message_data is None:
return
attributes = trace_info.metadata
attributes["tags"] = ["moderation"]
attributes["message_id"] = trace_info.message_id
attributes["start_time"] = trace_info.start_time or trace_info.message_data.created_at
attributes["end_time"] = trace_info.end_time or trace_info.message_data.updated_at
trace_id = trace_info.trace_id or trace_info.message_id
attributes["trace_id"] = trace_id
moderation_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.MODERATION_TRACE),
inputs=trace_info.inputs,
outputs={
"action": trace_info.action,
"flagged": trace_info.flagged,
"preset_response": trace_info.preset_response,
"inputs": trace_info.inputs,
},
attributes=attributes,
exception=getattr(trace_info, "error", None),
file_list=[],
)
self.start_call(moderation_run, parent_run_id=trace_id)
self.finish_call(moderation_run)
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
message_data = trace_info.message_data
if message_data is None:
return
attributes = trace_info.metadata
attributes["message_id"] = trace_info.message_id
attributes["tags"] = ["suggested_question"]
attributes["start_time"] = (trace_info.start_time or message_data.created_at,)
attributes["end_time"] = (trace_info.end_time or message_data.updated_at,)
trace_id = trace_info.trace_id or trace_info.message_id
attributes["trace_id"] = trace_id
suggested_question_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE),
inputs=trace_info.inputs,
outputs=trace_info.suggested_question,
attributes=attributes,
exception=trace_info.error,
file_list=[],
)
self.start_call(suggested_question_run, parent_run_id=trace_id)
self.finish_call(suggested_question_run)
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
return
attributes = trace_info.metadata
attributes["message_id"] = trace_info.message_id
attributes["tags"] = ["dataset_retrieval"]
attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,)
attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,)
trace_id = trace_info.trace_id or trace_info.message_id
attributes["trace_id"] = trace_id
dataset_retrieval_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE),
inputs=trace_info.inputs,
outputs={"documents": trace_info.documents},
attributes=attributes,
exception=getattr(trace_info, "error", None),
file_list=[],
)
self.start_call(dataset_retrieval_run, parent_run_id=trace_id)
self.finish_call(dataset_retrieval_run)
def tool_trace(self, trace_info: ToolTraceInfo):
attributes = trace_info.metadata
attributes["tags"] = ["tool", trace_info.tool_name]
attributes["start_time"] = trace_info.start_time
attributes["end_time"] = trace_info.end_time
        message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None)
trace_id = trace_info.trace_id or message_id
attributes["trace_id"] = trace_id
tool_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=trace_info.tool_name,
inputs=trace_info.tool_inputs,
outputs=trace_info.tool_outputs,
file_list=[cast(str, trace_info.file_url)] if trace_info.file_url else [],
attributes=attributes,
exception=trace_info.error,
)
self.start_call(tool_run, parent_run_id=trace_id)
self.finish_call(tool_run)
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
attributes = trace_info.metadata
attributes["tags"] = ["generate_name"]
attributes["start_time"] = trace_info.start_time
attributes["end_time"] = trace_info.end_time
name_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.GENERATE_NAME_TRACE),
inputs=trace_info.inputs,
outputs=trace_info.outputs,
attributes=attributes,
exception=getattr(trace_info, "error", None),
file_list=[],
)
self.start_call(name_run)
self.finish_call(name_run)
def api_check(self):
try:
if self.host:
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host)
else:
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
if not login_status:
raise ValueError("Weave login failed")
else:
logger.info("Weave login successful")
return True
except Exception as e:
logger.debug("Weave API check failed: %s", str(e))
raise ValueError(f"Weave API check failed: {str(e)}")
def _normalize_time(self, dt: datetime | None) -> datetime:
if dt is None:
return datetime.now(UTC)
if dt.tzinfo is None:
return dt.replace(tzinfo=UTC)
return dt
def start_call(self, run_data: WeaveTraceModel, parent_run_id: str | None = None):
inputs = run_data.inputs
if inputs is None:
inputs = {}
elif not isinstance(inputs, dict):
inputs = {"inputs": str(inputs)}
attributes = run_data.attributes
if attributes is None:
attributes = {}
elif not isinstance(attributes, dict):
attributes = {"attributes": str(attributes)}
start_time = attributes.get("start_time") if isinstance(attributes, dict) else None
started_at = self._normalize_time(start_time if isinstance(start_time, datetime) else None)
trace_id = attributes.get("trace_id") if isinstance(attributes, dict) else None
if trace_id is None:
trace_id = run_data.id
call_start_req = CallStartReq(
start=StartedCallSchemaForInsert(
project_id=self.project_id,
id=run_data.id,
op_name=str(run_data.op),
trace_id=trace_id,
parent_id=parent_run_id,
started_at=started_at,
attributes=attributes,
inputs=inputs,
wb_user_id=None,
)
)
self.weave_client.server.call_start(call_start_req)
self.calls[run_data.id] = {"trace_id": trace_id, "parent_id": parent_run_id}
def finish_call(self, run_data: WeaveTraceModel):
call_meta = self.calls.get(run_data.id)
if not call_meta:
raise ValueError(f"Call with id {run_data.id} not found")
attributes = run_data.attributes
if attributes is None:
attributes = {}
elif not isinstance(attributes, dict):
attributes = {"attributes": str(attributes)}
start_time = attributes.get("start_time") if isinstance(attributes, dict) else None
end_time = attributes.get("end_time") if isinstance(attributes, dict) else None
started_at = self._normalize_time(start_time if isinstance(start_time, datetime) else None)
ended_at = self._normalize_time(end_time if isinstance(end_time, datetime) else None)
elapsed_ms = int((ended_at - started_at).total_seconds() * 1000)
if elapsed_ms < 0:
elapsed_ms = 0
status_counts = {
TraceStatus.SUCCESS: 0,
TraceStatus.ERROR: 0,
}
if run_data.exception:
status_counts[TraceStatus.ERROR] = 1
else:
status_counts[TraceStatus.SUCCESS] = 1
summary: dict[str, Any] = {
"status_counts": status_counts,
"weave": {"latency_ms": elapsed_ms},
}
exception_str = str(run_data.exception) if run_data.exception else None
call_end_req = CallEndReq(
end=EndedCallSchemaForInsert(
project_id=self.project_id,
id=run_data.id,
ended_at=ended_at,
exception=exception_str,
output=run_data.outputs,
summary=cast(SummaryInsertMap, summary),
)
)
self.weave_client.server.call_end(call_end_req)
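
# Illustrative call lifecycle (sketch; the ids and config below are
# hypothetical): every Dify trace is exported as a start/end pair against the
# Weave call server, and parent_run_id threads the message -> workflow -> node
# hierarchy into Weave's call tree.
#
#     tracer = WeaveDataTrace(weave_config)  # a configured WeaveConfig
#     root = WeaveTraceModel(id="run-1", op="workflow", inputs={}, outputs={},
#                            attributes={"trace_id": "run-1"}, exception=None)
#     tracer.start_call(root)                 # emits CallStartReq
#     node = WeaveTraceModel(id="node-1", op="llm", inputs={}, outputs={},
#                            attributes={"trace_id": "run-1"}, exception=None)
#     tracer.start_call(node, parent_run_id="run-1")
#     tracer.finish_call(node)                # emits CallEndReq with latency
#     tracer.finish_call(root)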