107 lines
3.2 KiB
Python
107 lines
3.2 KiB
Python
import json
|
||
import re
|
||
from collections.abc import Generator
|
||
from io import BytesIO
|
||
from typing import Any
|
||
|
||
import fitz # PyMuPDF 核心库
|
||
from dify_plugin import Tool
|
||
from dify_plugin.entities.tool import ToolInvokeMessage
|
||
|
||
|
||
class PdfTool(Tool):
    """Dify tool that locates the table-of-contents (TOC) pages of an uploaded PDF.

    Pages are scanned in order; the first run of consecutive pages whose text
    matches one of ``_TOC_PATTERNS`` is treated as the TOC. The result is
    emitted twice (as a JSON text message and as a JSON message) with shape::

        {
            "start": int | None,   # 0-based index of the first TOC page
            "end": int | None,     # 0-based index of the last TOC page (inclusive)
            "pages": list[str],    # extracted text of each TOC page
            "pages_text": str,     # "pages" joined with newlines
        }
    """

    # TOC heading variants (Chinese, Chinese with ASCII / ideographic space,
    # English, Japanese); same set as the original implementation.
    _TOC_PATTERNS = (
        r'目录',
        r'目 录',
        r'目\u3000录',
        r'Table of Contents',
        r'Contents',
        r'目次',
    )
    # One alternation matches iff any single pattern matches, so this is
    # equivalent to searching each pattern separately, but compiled only once.
    _TOC_REGEX = re.compile("|".join(_TOC_PATTERNS), re.IGNORECASE)

    @classmethod
    def _scan_for_toc(cls, doc: "fitz.Document") -> tuple:
        """Find the first run of consecutive pages that look like a TOC.

        Args:
            doc: An open PyMuPDF document.

        Returns:
            ``(start, end, pages)`` where ``start``/``end`` are the 0-based,
            inclusive indices of the run and ``pages`` holds each page's text;
            ``(None, None, [])`` when no page matches.
        """
        toc_start = None
        toc_end = None
        toc_pages: list[str] = []
        for page_num, page in enumerate(doc):
            text = page.get_text() or ""
            if cls._TOC_REGEX.search(text):
                if toc_start is None:
                    toc_start = page_num
                toc_end = page_num
                # The run is consecutive, so caching matched text here yields
                # exactly pages start..end without extracting them a second time.
                toc_pages.append(text)
            elif toc_start is not None:
                # First non-matching page after the run: the TOC has ended.
                break
        return toc_start, toc_end, toc_pages

    def _invoke(self, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]:
        """Locate the TOC pages of ``tool_parameters["file"]``.

        Args:
            tool_parameters: Tool inputs; ``"file"`` is expected to be an
                uploaded file object exposing the raw PDF bytes as ``.blob``.

        Yields:
            An error text message when no file is supplied or the PDF cannot
            be opened; otherwise a JSON text message followed by a JSON
            message, both carrying the result dict described on the class.
        """
        file = tool_parameters.get("file")
        if not file:
            yield self.create_text_message("Error: file is required")
            return

        try:
            # Load the PDF from the in-memory byte stream (PyMuPDF replaces PyPDF2).
            doc = fitz.open(stream=file.blob, filetype="pdf")
        except (RuntimeError, ValueError) as exc:
            # PyMuPDF raises FileDataError / EmptyFileError (RuntimeError
            # subclasses) for corrupt or empty input; report it to the caller
            # instead of crashing the plugin invocation.
            yield self.create_text_message(f"Error: failed to open PDF: {exc}")
            return

        # ``with`` closes the document even if text extraction raises.
        with doc:
            toc_start, toc_end, toc_pages = self._scan_for_toc(doc)

        result = {
            "start": toc_start,
            "end": toc_end,
            "pages": toc_pages,
            "pages_text": "\n".join(toc_pages),
        }
        yield self.create_text_message(json.dumps(result, ensure_ascii=False))
        yield self.create_json_message(result)
|
||
|
||
|
||
if __name__ == "__main__":
    # Manual smoke test (PyMuPDF): locate the TOC of a local PDF and print it.
    # Usage: python <this file> [path/to/file.pdf]
    import sys

    # Sample manual used during development; a CLI argument overrides it.
    default_pdf_path = r"F:\Project\urbanLifeline\docs\AI训练资料\菱重S12R发动机说明书.pdf"
    pdf_path = sys.argv[1] if len(sys.argv) > 1 else default_pdf_path

    toc_patterns = [
        r'目录',
        r'目 录',
        r'目\u3000录',
        r'Table of Contents',
        r'Contents',
        r'目次',
    ]
    # A single alternation matches iff any individual pattern matches.
    toc_regex = re.compile("|".join(toc_patterns), re.IGNORECASE)

    # A local file can be opened directly; ``with`` guarantees it is closed.
    with fitz.open(pdf_path) as doc:
        num_pages = len(doc)
        toc_start = None
        toc_end = None

        # Find the first run of consecutive pages containing a TOC heading.
        for page_num in range(num_pages):
            text = doc[page_num].get_text() or ""
            if toc_regex.search(text):
                if toc_start is None:
                    toc_start = page_num
                toc_end = page_num
            elif toc_start is not None:
                break

        if toc_start is None:
            # No heading found: fall back to the fixed window (pages 18..27)
            # that fit the sample manual, clamped so short PDFs don't IndexError.
            toc_start = min(18, max(num_pages - 1, 0))
            toc_end = toc_start + 9
        toc_end = min(toc_end, num_pages - 1)

        # Inclusive of toc_end (as the tool does); the previous exclusive range
        # dropped the last TOC page and printed nothing for a one-page TOC.
        toc_pages = [
            doc[page_num].get_text() or ""
            for page_num in range(toc_start, toc_end + 1)
        ]

    print(toc_start, toc_end, toc_pages)