更新
This commit is contained in:
209
difyPlugin/pdf/tools/pdf_summary.py
Normal file
209
difyPlugin/pdf/tools/pdf_summary.py
Normal file
@@ -0,0 +1,209 @@
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
import fitz
|
||||
from dify_plugin import Tool
|
||||
from dify_plugin.entities.model.llm import LLMModelConfig
|
||||
from dify_plugin.entities.model.message import SystemPromptMessage, UserPromptMessage
|
||||
from dify_plugin.entities.tool import ToolInvokeMessage
|
||||
|
||||
|
||||
class PdfSummaryTool(Tool):
    """Fast PDF page summary tool.

    Default behavior is optimized for throughput in large workflows:
    - Extract plain text and lightweight table data only.
    - Skip expensive image base64 and drawing path extraction.
    - Skip LLM by default unless ``use_llm=true`` is explicitly passed.
    """

    def _invoke(self, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]:
        """Extract a page range from an uploaded PDF and yield results.

        Yields a JSON message with the per-page extraction payload, then a
        text message containing a Markdown rendering of the extracted content
        (optionally refined by an LLM when ``use_llm`` is enabled).

        Expected ``tool_parameters`` keys:
            file: uploaded file exposing raw bytes via ``.blob`` (required).
            pdf_start_page / pdf_end_page: 0-based page range, clamped to the
                document; an end before the start collapses to a single page.
            use_llm: opt-in flag for LLM refinement (default off).
            model: LLM model config dict; only used when ``use_llm`` is true.
            max_chars_per_page: per-page text cap, clamped to [800, 20000].
            llm_prompt: system prompt for the optional LLM pass.
        """
        file = tool_parameters.get("file")
        if not file:
            yield self.create_text_message("Error: file is required")
            return

        start_page = self._to_int(tool_parameters.get("pdf_start_page"), 0)
        end_page = self._to_int(tool_parameters.get("pdf_end_page"), 0)
        model_config = tool_parameters.get("model")
        use_llm = self._to_bool(tool_parameters.get("use_llm"), False)

        # Clamp the per-page budget to a sane band so a caller can neither
        # disable truncation entirely nor truncate pages down to nothing.
        max_chars_per_page = self._to_int(tool_parameters.get("max_chars_per_page"), 6000)
        max_chars_per_page = max(800, min(max_chars_per_page, 20000))

        # Fall back to the default prompt when the parameter is missing OR
        # blank (plain dict.get(default) would pass an empty prompt through).
        llm_prompt = tool_parameters.get("llm_prompt") or (
            "请基于输入的PDF页面文本做简洁准确摘要,输出中文要点。不要输出思考过程。"
        )

        pdf_bytes = file.blob
        doc = fitz.open(stream=pdf_bytes, filetype="pdf")
        try:
            num_pages = len(doc)
            # Bug fix: an empty document previously fell through the clamping
            # math to doc[0] and raised an uncaught IndexError.
            if num_pages == 0:
                yield self.create_text_message("Error: PDF has no pages")
                return

            start_page = max(0, min(start_page, num_pages - 1))
            end_page = max(start_page, min(end_page, num_pages - 1))

            pages_data: list[dict[str, Any]] = [
                self._extract_page_fast(doc[page_idx], page_idx, max_chars_per_page)
                for page_idx in range(start_page, end_page + 1)
            ]

            result = {
                "total_pages_extracted": len(pages_data),
                "page_range": {"start": start_page, "end": end_page},
                "pages": pages_data,
            }
            yield self.create_json_message(result)

            # Fast local summary first (deterministic, no model latency).
            local_text = self._build_local_summary(pages_data)

            # Optional LLM refinement, explicitly enabled only; an empty (or
            # failed) model answer falls back to the local summary.
            if use_llm and model_config:
                refined = self._summarize_with_llm(local_text, llm_prompt, model_config)
                final_text = refined if refined else local_text
            else:
                final_text = local_text

            if final_text:
                yield self.create_text_message(final_text)
        finally:
            doc.close()

    def _extract_page_fast(self, page: fitz.Page, page_idx: int, max_chars_per_page: int) -> dict[str, Any]:
        """Extract plain text and up to 3 lightweight tables from one page.

        Image, drawing and text-block extraction are intentionally skipped
        for speed; their keys are kept (empty) so the output schema stays
        stable for downstream consumers.
        """
        text = (page.get_text("text") or "").strip()
        if len(text) > max_chars_per_page:
            text = text[:max_chars_per_page] + "\n...[truncated]"

        tables: list[dict[str, Any]] = []
        try:
            tabs = page.find_tables()
            # Cap per-page table work: at most 3 tables, 10 rows each.
            for tab_idx, tab in enumerate(tabs.tables[:3]):
                cells = tab.extract() or []
                tables.append(
                    {
                        "index": tab_idx,
                        "rows": tab.row_count,
                        "cols": tab.col_count,
                        "cells": cells[:10],
                    }
                )
        except Exception:
            # Table detection is best-effort; never fail the page for it.
            pass

        return {
            "page_number": page_idx,
            "text": text,
            "tables": tables,
            "images": [],
            "drawings_summary": [],
            "text_blocks": [],
            "width": float(page.rect.width),
            "height": float(page.rect.height),
        }

    def _build_local_summary(self, pages_data: list[dict[str, Any]]) -> str:
        """Output actual page content as Markdown (text + tables).

        No LLM needed downstream — the text is already usable Markdown.
        Pages are joined with a visible page-break separator.
        """
        parts: list[str] = []
        for page in pages_data:
            text = (page.get("text") or "").strip()
            tables = page.get("tables") or []

            page_parts: list[str] = []
            if text:
                page_parts.append(text)

            for tab in tables:
                cells = tab.get("cells") or []
                # Need at least a header row plus one data row to render.
                if len(cells) >= 2:
                    md = self._cells_to_md_table(cells)
                    if md:
                        page_parts.append(md)

            if page_parts:
                parts.append("\n\n".join(page_parts))

        return "\n\n--- 分页 ---\n\n".join(parts)

    @staticmethod
    def _cells_to_md_table(cells: list) -> str:
        """Render extracted table cells as a Markdown table.

        The first row is treated as the header; data rows are padded or
        trimmed to the header's column count. Returns "" when the input
        cannot form a table.
        """
        if not cells:
            return ""
        header = cells[0]
        ncols = len(header)
        if ncols == 0:
            return ""

        def clean(cell: Any) -> str:
            # Escape pipes and flatten newlines so cell text cannot break
            # the Markdown table structure.
            return str(cell or "").replace("|", "\\|").replace("\n", " ").strip()

        lines = [
            "| " + " | ".join(clean(c) for c in header) + " |",
            "| " + " | ".join("---" for _ in range(ncols)) + " |",
        ]
        for row in cells[1:]:
            padded = list(row) + [""] * max(0, ncols - len(row))
            lines.append("| " + " | ".join(clean(c) for c in padded[:ncols]) + " |")
        return "\n".join(lines)

    def _summarize_with_llm(self, local_text: str, llm_prompt: str, model_config: dict[str, Any]) -> str:
        """Run one non-streaming LLM call over the local summary.

        Returns the cleaned model answer, or "" on any failure so the caller
        falls back to the local summary instead of aborting mid-invocation.
        """
        try:
            response = self.session.model.llm.invoke(
                model_config=LLMModelConfig(**model_config),
                prompt_messages=[
                    SystemPromptMessage(content=llm_prompt),
                    UserPromptMessage(content=local_text),
                ],
                stream=False,
            )
        except Exception:
            # Robustness fix: a provider/model error must not crash the tool
            # after the JSON payload has already been yielded.
            return ""

        llm_text = ""
        if hasattr(response, "message") and response.message:
            content = response.message.content
            if isinstance(content, str):
                llm_text = content
            elif isinstance(content, list):
                # Multimodal content: concatenate textual parts.
                llm_text = "".join(
                    item.data if hasattr(item, "data") else str(item)
                    for item in content
                )

        return self._extract_visible_answer(llm_text)

    @staticmethod
    def _extract_visible_answer(text: str) -> str:
        """Strip model reasoning/markup, keeping only the visible answer.

        Prefers the content inside <|begin_of_box|>...<|end_of_box|> when
        present; otherwise removes <think>...</think> blocks. Any remaining
        <|...|> special tokens are dropped.
        """
        if not text:
            return ""

        box_match = re.search(r"<\|begin_of_box\|>([\s\S]*?)<\|end_of_box\|>", text)
        if box_match:
            text = box_match.group(1)
        else:
            text = re.sub(r"<think>[\s\S]*?</think>", "", text, flags=re.IGNORECASE)

        text = re.sub(r"<\|[^>]+\|>", "", text)
        return text.strip()

    @staticmethod
    def _to_int(value: Any, default: int) -> int:
        """Coerce *value* to int; None/empty/unparseable yields *default*."""
        try:
            if value is None or value == "":
                return default
            return int(value)
        except Exception:
            return default

    @staticmethod
    def _to_bool(value: Any, default: bool) -> bool:
        """Coerce *value* to bool; unrecognized strings yield *default*."""
        if value is None:
            return default
        if isinstance(value, bool):
            return value
        s = str(value).strip().lower()
        if s in {"1", "true", "yes", "on"}:
            return True
        if s in {"0", "false", "no", "off"}:
            return False
        return default
|
||||
Reference in New Issue
Block a user