210 lines
7.3 KiB
Python
210 lines
7.3 KiB
Python
|
|
import json
|
|||
|
|
import re
|
|||
|
|
from collections.abc import Generator
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
import fitz
|
|||
|
|
from dify_plugin import Tool
|
|||
|
|
from dify_plugin.entities.model.llm import LLMModelConfig
|
|||
|
|
from dify_plugin.entities.model.message import SystemPromptMessage, UserPromptMessage
|
|||
|
|
from dify_plugin.entities.tool import ToolInvokeMessage
|
|||
|
|
|
|||
|
|
|
|||
|
|
class PdfSummaryTool(Tool):
    """Fast PDF page summary tool.

    Default behavior is optimized for throughput in large workflows:
    - Extract plain text and lightweight table data only (no image base64,
      no drawing path extraction).
    - Build a deterministic local Markdown summary with no model latency.
    - Skip the LLM entirely unless `use_llm=true` is explicitly passed
      together with a model configuration.
    """
|
|||
|
|
|
|||
|
|
def _invoke(self, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]:
    """Extract a page range from an uploaded PDF and yield results.

    Emits a JSON message with the per-page extraction payload, then a text
    message with a locally-built Markdown summary. An LLM refinement pass
    runs only when `use_llm` is true AND a model config is provided.
    """
    uploaded = tool_parameters.get("file")
    if not uploaded:
        yield self.create_text_message("Error: file is required")
        return

    first = self._to_int(tool_parameters.get("pdf_start_page"), 0)
    last = self._to_int(tool_parameters.get("pdf_end_page"), 0)
    model_config = tool_parameters.get("model")
    use_llm = self._to_bool(tool_parameters.get("use_llm"), False)

    # Clamp the per-page text budget to a sane window (800..20000 chars).
    char_budget = self._to_int(tool_parameters.get("max_chars_per_page"), 6000)
    char_budget = min(max(char_budget, 800), 20000)

    llm_prompt = tool_parameters.get(
        "llm_prompt",
        "请基于输入的PDF页面文本做简洁准确摘要,输出中文要点。不要输出思考过程。",
    )

    doc = fitz.open(stream=uploaded.blob, filetype="pdf")
    try:
        page_count = len(doc)
        # Clamp the requested range into the document's bounds;
        # `last` can never precede `first`.
        first = max(0, min(first, page_count - 1))
        last = max(first, min(last, page_count - 1))

        pages_data = [
            self._extract_page_fast(doc[idx], idx, char_budget)
            for idx in range(first, last + 1)
        ]

        yield self.create_json_message(
            {
                "total_pages_extracted": len(pages_data),
                "page_range": {"start": first, "end": last},
                "pages": pages_data,
            }
        )

        # Deterministic local summary first — no model latency.
        local_text = self._build_local_summary(pages_data)

        # Optional LLM refinement, explicitly enabled only; fall back to
        # the local summary when the model returns nothing.
        final_text = local_text
        if use_llm and model_config:
            refined = self._summarize_with_llm(local_text, llm_prompt, model_config)
            if refined:
                final_text = refined

        if final_text:
            yield self.create_text_message(final_text)
    finally:
        doc.close()
|
|||
|
|
|
|||
|
|
def _extract_page_fast(self, page: fitz.Page, page_idx: int, max_chars_per_page: int) -> dict[str, Any]:
    """Pull plain text plus up to three lightweight table previews from one page.

    Image/drawing/text-block extraction is intentionally skipped (empty
    lists) to keep throughput high; the keys stay present for schema
    stability downstream.
    """
    text = (page.get_text("text") or "").strip()
    if len(text) > max_chars_per_page:
        # Bound the payload size for very dense pages.
        text = f"{text[:max_chars_per_page]}\n...[truncated]"

    tables: list[dict[str, Any]] = []
    try:
        finder = page.find_tables()
        for idx, table in enumerate(finder.tables[:3]):
            extracted = table.extract() or []
            tables.append(
                {
                    "index": idx,
                    "rows": table.row_count,
                    "cols": table.col_count,
                    # Only a preview of the first rows is kept.
                    "cells": extracted[:10],
                }
            )
    except Exception:
        # Table detection is best-effort; failures must not break extraction.
        pass

    return {
        "page_number": page_idx,
        "text": text,
        "tables": tables,
        "images": [],
        "drawings_summary": [],
        "text_blocks": [],
        "width": float(page.rect.width),
        "height": float(page.rect.height),
    }
|
|||
|
|
|
|||
|
|
def _build_local_summary(self, pages_data: list[dict[str, Any]]) -> str:
|
|||
|
|
"""Output actual page content as Markdown (text + tables).
|
|||
|
|
|
|||
|
|
No LLM needed downstream — the text is already usable Markdown.
|
|||
|
|
"""
|
|||
|
|
parts: list[str] = []
|
|||
|
|
for page in pages_data:
|
|||
|
|
text = (page.get("text") or "").strip()
|
|||
|
|
tables = page.get("tables") or []
|
|||
|
|
|
|||
|
|
page_parts: list[str] = []
|
|||
|
|
if text:
|
|||
|
|
page_parts.append(text)
|
|||
|
|
|
|||
|
|
for tab in tables:
|
|||
|
|
cells = tab.get("cells") or []
|
|||
|
|
if len(cells) >= 2:
|
|||
|
|
md = self._cells_to_md_table(cells)
|
|||
|
|
if md:
|
|||
|
|
page_parts.append(md)
|
|||
|
|
|
|||
|
|
if page_parts:
|
|||
|
|
parts.append("\n\n".join(page_parts))
|
|||
|
|
|
|||
|
|
return "\n\n--- 分页 ---\n\n".join(parts)
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _cells_to_md_table(cells: list) -> str:
|
|||
|
|
if not cells:
|
|||
|
|
return ""
|
|||
|
|
header = cells[0]
|
|||
|
|
ncols = len(header)
|
|||
|
|
if ncols == 0:
|
|||
|
|
return ""
|
|||
|
|
clean = lambda c: str(c or "").replace("|", "\\|").replace("\n", " ").strip()
|
|||
|
|
lines = [
|
|||
|
|
"| " + " | ".join(clean(c) for c in header) + " |",
|
|||
|
|
"| " + " | ".join("---" for _ in range(ncols)) + " |",
|
|||
|
|
]
|
|||
|
|
for row in cells[1:]:
|
|||
|
|
padded = list(row) + [""] * max(0, ncols - len(row))
|
|||
|
|
lines.append("| " + " | ".join(clean(c) for c in padded[:ncols]) + " |")
|
|||
|
|
return "\n".join(lines)
|
|||
|
|
|
|||
|
|
def _summarize_with_llm(self, local_text: str, llm_prompt: str, model_config: dict[str, Any]) -> str:
    """Send the local summary through the configured LLM and clean its reply.

    Returns "" when the response carries no usable message content; the
    caller falls back to the local summary in that case.
    """
    response = self.session.model.llm.invoke(
        model_config=LLMModelConfig(**model_config),
        prompt_messages=[
            SystemPromptMessage(content=llm_prompt),
            UserPromptMessage(content=local_text),
        ],
        stream=False,
    )

    raw = ""
    message = getattr(response, "message", None)
    if message:
        content = message.content
        if isinstance(content, str):
            raw = content
        elif isinstance(content, list):
            # Multi-part content: concatenate each item's data payload.
            pieces = [
                item.data if hasattr(item, "data") else str(item)
                for item in content
            ]
            raw = "".join(pieces)

    return self._extract_visible_answer(raw)
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _extract_visible_answer(text: str) -> str:
|
|||
|
|
if not text:
|
|||
|
|
return ""
|
|||
|
|
|
|||
|
|
box_match = re.search(r"<\|begin_of_box\|>([\s\S]*?)<\|end_of_box\|>", text)
|
|||
|
|
if box_match:
|
|||
|
|
text = box_match.group(1)
|
|||
|
|
else:
|
|||
|
|
text = re.sub(r"<think>[\s\S]*?</think>", "", text, flags=re.IGNORECASE)
|
|||
|
|
|
|||
|
|
text = re.sub(r"<\|[^>]+\|>", "", text)
|
|||
|
|
return text.strip()
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _to_int(value: Any, default: int) -> int:
|
|||
|
|
try:
|
|||
|
|
if value is None or value == "":
|
|||
|
|
return default
|
|||
|
|
return int(value)
|
|||
|
|
except Exception:
|
|||
|
|
return default
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _to_bool(value: Any, default: bool) -> bool:
|
|||
|
|
if value is None:
|
|||
|
|
return default
|
|||
|
|
if isinstance(value, bool):
|
|||
|
|
return value
|
|||
|
|
s = str(value).strip().lower()
|
|||
|
|
if s in {"1", "true", "yes", "on"}:
|
|||
|
|
return True
|
|||
|
|
if s in {"0", "false", "no", "off"}:
|
|||
|
|
return False
|
|||
|
|
return default
|