# urbanLifeline/difyPlugin/pdf/tools/pdf_summary.py
# Last modified: 2026-03-06 14:50:43 +08:00
#
# NOTE(review): the original paste carried web-UI chrome ("Files",
# "Raw Blame History", an ambiguous-Unicode warning) above the imports,
# which made the file invalid Python. It is preserved here as comments.
# The non-ASCII (Chinese) string literals below are intentional runtime
# strings, not encoding artifacts.
import json
import re
from collections.abc import Generator
from typing import Any
import fitz
from dify_plugin import Tool
from dify_plugin.entities.model.llm import LLMModelConfig
from dify_plugin.entities.model.message import SystemPromptMessage, UserPromptMessage
from dify_plugin.entities.tool import ToolInvokeMessage
class PdfSummaryTool(Tool):
    """Fast PDF page summary tool.

    Default behavior is optimized for throughput in large workflows:
    - Extract plain text and lightweight table data only.
    - Skip expensive image base64 and drawing path extraction.
    - Skip LLM by default unless `use_llm=true` is explicitly passed.
    """

    def _invoke(self, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
        """Extract a page range from the uploaded PDF and emit structured output.

        Yields a JSON message with per-page data (text + table previews),
        then a text message containing a Markdown rendering of the content,
        optionally refined by an LLM when `use_llm` is enabled and a model
        config is supplied.
        """
        file = tool_parameters.get("file")
        if not file:
            yield self.create_text_message("Error: file is required")
            return

        start_page = self._to_int(tool_parameters.get("pdf_start_page"), 0)
        end_page = self._to_int(tool_parameters.get("pdf_end_page"), 0)
        model_config = tool_parameters.get("model")
        use_llm = self._to_bool(tool_parameters.get("use_llm"), False)

        # Clamp the per-page text budget into a sane range [800, 20000].
        max_chars_per_page = self._to_int(tool_parameters.get("max_chars_per_page"), 6000)
        max_chars_per_page = max(800, min(max_chars_per_page, 20000))

        # Default system prompt is a Chinese instruction used at runtime —
        # it must stay exactly as authored.
        llm_prompt = tool_parameters.get(
            "llm_prompt",
            "请基于输入的PDF页面文本做简洁准确摘要输出中文要点。不要输出思考过程。",
        )

        pdf_bytes = file.blob
        doc = fitz.open(stream=pdf_bytes, filetype="pdf")
        try:
            num_pages = len(doc)
            # Guard: a zero-page document would otherwise clamp both bounds
            # to 0 and crash on doc[0].
            if num_pages == 0:
                yield self.create_text_message("Error: PDF contains no pages")
                return

            # Clamp the requested (0-based, inclusive) range into bounds.
            start_page = max(0, min(start_page, num_pages - 1))
            end_page = max(start_page, min(end_page, num_pages - 1))

            pages_data: list[dict[str, Any]] = [
                self._extract_page_fast(doc[page_idx], page_idx, max_chars_per_page)
                for page_idx in range(start_page, end_page + 1)
            ]

            yield self.create_json_message(
                {
                    "total_pages_extracted": len(pages_data),
                    "page_range": {"start": start_page, "end": end_page},
                    "pages": pages_data,
                }
            )

            # Fast local summary first (deterministic, no model latency).
            local_text = self._build_local_summary(pages_data)

            # Optional LLM refinement, explicitly enabled only; fall back to
            # the local text when the model returns nothing usable.
            final_text = local_text
            if use_llm and model_config:
                refined = self._summarize_with_llm(local_text, llm_prompt, model_config)
                if refined:
                    final_text = refined
            if final_text:
                yield self.create_text_message(final_text)
        finally:
            doc.close()

    def _extract_page_fast(self, page: fitz.Page, page_idx: int, max_chars_per_page: int) -> dict[str, Any]:
        """Extract plain text plus up to three lightweight table previews.

        Images, drawings, and text blocks are intentionally left empty to
        keep the payload small and extraction fast.
        """
        text = (page.get_text("text") or "").strip()
        if len(text) > max_chars_per_page:
            text = text[:max_chars_per_page] + "\n...[truncated]"

        tables: list[dict[str, Any]] = []
        try:
            tabs = page.find_tables()
            # Preview only: cap at 3 tables per page and 10 rows per table.
            for tab_idx, tab in enumerate(tabs.tables[:3]):
                cells = tab.extract() or []
                tables.append(
                    {
                        "index": tab_idx,
                        "rows": tab.row_count,
                        "cols": tab.col_count,
                        "cells": cells[:10],
                    }
                )
        except Exception:
            # Table detection is best-effort; never fail the page over it.
            pass

        return {
            "page_number": page_idx,
            "text": text,
            "tables": tables,
            "images": [],
            "drawings_summary": [],
            "text_blocks": [],
            "width": float(page.rect.width),
            "height": float(page.rect.height),
        }

    def _build_local_summary(self, pages_data: list[dict[str, Any]]) -> str:
        """Render extracted page content (text + tables) as Markdown.

        No LLM is needed downstream — the result is already usable Markdown.
        Pages are joined with the original Chinese page separator, which
        downstream consumers may rely on.
        """
        parts: list[str] = []
        for page in pages_data:
            text = (page.get("text") or "").strip()
            tables = page.get("tables") or []
            page_parts: list[str] = []
            if text:
                page_parts.append(text)
            for tab in tables:
                cells = tab.get("cells") or []
                # Need at least a header row plus one data row for a table.
                if len(cells) >= 2:
                    md = self._cells_to_md_table(cells)
                    if md:
                        page_parts.append(md)
            if page_parts:
                parts.append("\n\n".join(page_parts))
        return "\n\n--- 分页 ---\n\n".join(parts)

    @staticmethod
    def _cells_to_md_table(cells: list) -> str:
        """Render extracted table cells as a GitHub-style Markdown table.

        The first row is treated as the header; short rows are padded with
        empty cells and long rows are truncated to the header width.
        Returns "" when there is nothing to render.
        """
        if not cells:
            return ""
        header = cells[0]
        ncols = len(header)
        if ncols == 0:
            return ""

        def clean(cell: Any) -> str:
            # Escape pipes and flatten newlines so each cell stays on one
            # Markdown table row; None becomes "".
            return str(cell or "").replace("|", "\\|").replace("\n", " ").strip()

        lines = [
            "| " + " | ".join(clean(c) for c in header) + " |",
            "| " + " | ".join("---" for _ in range(ncols)) + " |",
        ]
        for row in cells[1:]:
            padded = list(row) + [""] * max(0, ncols - len(row))
            lines.append("| " + " | ".join(clean(c) for c in padded[:ncols]) + " |")
        return "\n".join(lines)

    def _summarize_with_llm(self, local_text: str, llm_prompt: str, model_config: dict[str, Any]) -> str:
        """Ask the configured LLM to refine the local summary.

        Returns the model's visible answer text, or "" when the response
        carries no usable content.
        """
        response = self.session.model.llm.invoke(
            model_config=LLMModelConfig(**model_config),
            prompt_messages=[
                SystemPromptMessage(content=llm_prompt),
                UserPromptMessage(content=local_text),
            ],
            stream=False,
        )
        llm_text = ""
        if hasattr(response, "message") and response.message:
            content = response.message.content
            if isinstance(content, str):
                llm_text = content
            elif isinstance(content, list):
                # Multimodal responses: concatenate the textual parts.
                llm_text = "".join(
                    item.data if hasattr(item, "data") else str(item)
                    for item in content
                )
        return self._extract_visible_answer(llm_text)

    @staticmethod
    def _extract_visible_answer(text: str) -> str:
        """Strip reasoning/markup wrappers from raw model output.

        Prefers the content of a <|begin_of_box|>...<|end_of_box|> span when
        present; otherwise removes <think>...</think> blocks. Any remaining
        <|...|> special tokens are dropped in either case.
        """
        if not text:
            return ""
        box_match = re.search(r"<\|begin_of_box\|>([\s\S]*?)<\|end_of_box\|>", text)
        if box_match:
            text = box_match.group(1)
        else:
            text = re.sub(r"<think>[\s\S]*?</think>", "", text, flags=re.IGNORECASE)
        text = re.sub(r"<\|[^>]+\|>", "", text)
        return text.strip()

    @staticmethod
    def _to_int(value: Any, default: int) -> int:
        """Parse *value* as int, returning *default* for None, "", or junk."""
        if value is None or value == "":
            return default
        try:
            return int(value)
        except (TypeError, ValueError):
            # Only these can be raised by int() on arbitrary input.
            return default

    @staticmethod
    def _to_bool(value: Any, default: bool) -> bool:
        """Parse common truthy/falsy string forms; fall back to *default*."""
        if value is None:
            return default
        if isinstance(value, bool):
            return value
        s = str(value).strip().lower()
        if s in {"1", "true", "yes", "on"}:
            return True
        if s in {"0", "false", "no", "off"}:
            return False
        return default