"""Fast PDF page summary tool plugin (text + lightweight tables, optional LLM)."""

import json
import re
from collections.abc import Generator
from typing import Any

import fitz
from dify_plugin import Tool
from dify_plugin.entities.model.llm import LLMModelConfig
from dify_plugin.entities.model.message import SystemPromptMessage, UserPromptMessage
from dify_plugin.entities.tool import ToolInvokeMessage
class PdfSummaryTool(Tool):
    """Fast PDF page summary tool.

    Default behavior is optimized for throughput in large workflows:
    - Extract plain text and lightweight table data only.
    - Skip expensive image base64 and drawing path extraction.
    - Skip LLM by default unless `use_llm=true` is explicitly passed.
    """

    def _invoke(self, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
        """Extract a page range from the uploaded PDF.

        Yields a JSON message with the raw per-page extraction, then a text
        message containing a Markdown summary (optionally refined by an LLM).
        """
        file = tool_parameters.get("file")
        if not file:
            yield self.create_text_message("Error: file is required")
            return

        start_page = self._to_int(tool_parameters.get("pdf_start_page"), 0)
        end_page = self._to_int(tool_parameters.get("pdf_end_page"), 0)
        model_config = tool_parameters.get("model")
        use_llm = self._to_bool(tool_parameters.get("use_llm"), False)

        # Clamp the per-page text budget into a sane window [800, 20000].
        max_chars_per_page = self._to_int(tool_parameters.get("max_chars_per_page"), 6000)
        max_chars_per_page = max(800, min(max_chars_per_page, 20000))

        llm_prompt = tool_parameters.get(
            "llm_prompt",
            "请基于输入的PDF页面文本做简洁准确摘要,输出中文要点。不要输出思考过程。",
        )

        pdf_bytes = file.blob
        doc = fitz.open(stream=pdf_bytes, filetype="pdf")
        try:
            # Clamp the requested range into [0, num_pages - 1], keeping start <= end.
            num_pages = len(doc)
            start_page = max(0, min(start_page, num_pages - 1))
            end_page = max(start_page, min(end_page, num_pages - 1))

            pages_data: list[dict[str, Any]] = [
                self._extract_page_fast(doc[page_idx], page_idx, max_chars_per_page)
                for page_idx in range(start_page, end_page + 1)
            ]

            result = {
                "total_pages_extracted": len(pages_data),
                "page_range": {"start": start_page, "end": end_page},
                "pages": pages_data,
            }
            yield self.create_json_message(result)

            # Fast local summary first (deterministic, no model latency).
            local_text = self._build_local_summary(pages_data)

            # Optional LLM refinement, explicitly enabled only; fall back to
            # the local summary when the model returns nothing or fails.
            if use_llm and model_config:
                refined = self._summarize_with_llm(local_text, llm_prompt, model_config)
                final_text = refined if refined else local_text
            else:
                final_text = local_text

            if final_text:
                yield self.create_text_message(final_text)
        finally:
            doc.close()

    def _extract_page_fast(self, page: fitz.Page, page_idx: int, max_chars_per_page: int) -> dict[str, Any]:
        """Extract plain text (truncated to *max_chars_per_page*) and up to
        three lightweight table previews from a single page.

        Heavy fields (images, drawings, text blocks) are intentionally left
        empty for throughput.
        """
        text = (page.get_text("text") or "").strip()
        if len(text) > max_chars_per_page:
            text = text[:max_chars_per_page] + "\n...[truncated]"

        tables: list[dict[str, Any]] = []
        try:
            tabs = page.find_tables()
            # Cap at 3 tables per page; keep only the first 10 rows of each.
            for tab_idx, tab in enumerate(tabs.tables[:3]):
                cells = tab.extract() or []
                tables.append(
                    {
                        "index": tab_idx,
                        "rows": tab.row_count,
                        "cols": tab.col_count,
                        "cells": cells[:10],
                    }
                )
        except Exception:
            # Table detection is best-effort; never fail the whole extraction.
            pass

        return {
            "page_number": page_idx,
            "text": text,
            "tables": tables,
            "images": [],
            "drawings_summary": [],
            "text_blocks": [],
            "width": float(page.rect.width),
            "height": float(page.rect.height),
        }

    def _build_local_summary(self, pages_data: list[dict[str, Any]]) -> str:
        """Output actual page content as Markdown (text + tables).

        No LLM needed downstream — the text is already usable Markdown.
        """
        parts: list[str] = []
        for page in pages_data:
            text = (page.get("text") or "").strip()
            tables = page.get("tables") or []

            page_parts: list[str] = []
            if text:
                page_parts.append(text)

            for tab in tables:
                cells = tab.get("cells") or []
                # Need at least a header row plus one data row to be useful.
                if len(cells) >= 2:
                    md = self._cells_to_md_table(cells)
                    if md:
                        page_parts.append(md)

            if page_parts:
                parts.append("\n\n".join(page_parts))

        return "\n\n--- 分页 ---\n\n".join(parts)

    @staticmethod
    def _cells_to_md_table(cells: list) -> str:
        """Render extracted table cells as a Markdown table.

        The first row is used as the header; short rows are padded and long
        rows truncated so every row has the header's column count.
        """
        if not cells:
            return ""
        header = cells[0]
        ncols = len(header)
        if ncols == 0:
            return ""

        def clean(cell: Any) -> str:
            # Escape pipes and flatten newlines so a cell stays on one row.
            return str(cell or "").replace("|", "\\|").replace("\n", " ").strip()

        lines = [
            "| " + " | ".join(clean(c) for c in header) + " |",
            "| " + " | ".join("---" for _ in range(ncols)) + " |",
        ]
        for row in cells[1:]:
            padded = list(row) + [""] * max(0, ncols - len(row))
            lines.append("| " + " | ".join(clean(c) for c in padded[:ncols]) + " |")
        return "\n".join(lines)

    def _summarize_with_llm(self, local_text: str, llm_prompt: str, model_config: dict[str, Any]) -> str:
        """Ask the configured LLM to refine the local summary.

        Returns "" on any invocation failure so the caller falls back to the
        deterministic local summary instead of crashing the tool.
        """
        try:
            response = self.session.model.llm.invoke(
                model_config=LLMModelConfig(**model_config),
                prompt_messages=[
                    SystemPromptMessage(content=llm_prompt),
                    UserPromptMessage(content=local_text),
                ],
                stream=False,
            )
        except Exception:
            # LLM refinement is optional best-effort; never break the output.
            return ""

        llm_text = ""
        if hasattr(response, "message") and response.message:
            content = response.message.content
            if isinstance(content, str):
                llm_text = content
            elif isinstance(content, list):
                # Multimodal content list: concatenate the textual parts.
                llm_text = "".join(
                    item.data if hasattr(item, "data") else str(item) for item in content
                )

        return self._extract_visible_answer(llm_text)

    @staticmethod
    def _extract_visible_answer(text: str) -> str:
        """Strip model control markup, returning only the visible answer.

        Prefers the content inside a boxed-answer span when present; otherwise
        removes <think> blocks, then drops any remaining <|...|> control tokens.
        """
        if not text:
            return ""

        box_match = re.search(r"<\|begin_of_box\|>([\s\S]*?)<\|end_of_box\|>", text)
        if box_match:
            text = box_match.group(1)
        else:
            text = re.sub(r"<think>[\s\S]*?</think>", "", text, flags=re.IGNORECASE)

        text = re.sub(r"<\|[^>]+\|>", "", text)
        return text.strip()

    @staticmethod
    def _to_int(value: Any, default: int) -> int:
        """Coerce *value* to int, returning *default* for None/""/unparseable."""
        try:
            if value is None or value == "":
                return default
            return int(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _to_bool(value: Any, default: bool) -> bool:
        """Coerce common truthy/falsy spellings to bool; *default* otherwise."""
        if value is None:
            return default
        if isinstance(value, bool):
            return value
        s = str(value).strip().lower()
        if s in {"1", "true", "yes", "on"}:
            return True
        if s in {"0", "false", "no", "off"}:
            return False
        return default
|