import json
import re
from collections import OrderedDict
from collections.abc import Generator
from typing import Any

from dify_plugin import Tool
from dify_plugin.entities.model.llm import LLMModelConfig
from dify_plugin.entities.model.message import SystemPromptMessage, UserPromptMessage
from dify_plugin.entities.tool import ToolInvokeMessage

_SYSTEM_PROMPT = """You parse PDF table-of-contents text.
|
||
Return only valid JSON object, no markdown fences, no explanation.
|
||
Output schema:
|
||
{
|
||
"Chapter Name": {"start": 1, "end": 5},
|
||
"Another": {"start": 6, "end": 20}
|
||
}
|
||
Rules:
|
||
- start/end are integer printed page numbers from TOC.
|
||
- If end is unknown, use same value as start.
|
||
- Keep chapter names exactly as in TOC text.
|
||
"""
|
||
|
||
|
||
class PdfTocTool(Tool):
    """Parse PDF table-of-contents text into a chapter -> page-range catalog."""

    def _invoke(
        self, tool_parameters: dict[str, Any]
    ) -> Generator[ToolInvokeMessage, None, None]:
        """Tool entry point.

        Expected tool parameters:
          - toc_start / toc_end: page indexes of the TOC inside the PDF (required)
          - toc_pages: raw TOC text extracted from those pages (required)
          - model: optional LLM model config, used only as a fallback parser

        Yields a text message carrying the result as a JSON string (for
        downstream ``json.loads``) and a structured JSON message.
        """
        toc_start = self._to_int(tool_parameters.get("toc_start"), None)
        toc_end = self._to_int(tool_parameters.get("toc_end"), None)
        toc_pages = (tool_parameters.get("toc_pages") or "").strip()
        model_config = tool_parameters.get("model")

        if toc_start is None or toc_end is None:
            yield self.create_text_message("Error: toc_start and toc_end are required")
            return

        if not toc_pages:
            yield self.create_text_message("Error: toc_pages text is empty")
            return

        cleaned = self._strip_index_lists(toc_pages)

        # 1) deterministic parser first
        rule_catalog = self._parse_toc_lines(cleaned)
        catalog = rule_catalog

        # 2) optional LLM fallback only when the deterministic parser gives no result
        llm_raw_output = ""
        llm_error = None
        llm_attempted = False
        if not catalog and model_config:
            llm_attempted = True
            llm_catalog, llm_raw_output, llm_error = self._parse_with_llm(
                toc_start=toc_start,
                toc_end=toc_end,
                toc_pages=cleaned,
                model_config=model_config,
            )
            if llm_catalog:
                catalog = self._normalize_catalog(llm_catalog)

        # BUG FIX: the parser label used to be "rule" whenever the final catalog
        # was non-empty, which mislabeled catalogs produced by the LLM fallback.
        # Report the component that actually produced the result.
        if rule_catalog:
            parser = "rule"
        elif catalog:
            parser = "llm"
        else:
            parser = "none"

        result: dict[str, Any] = {
            "toc_start": toc_start,
            "toc_end": toc_end,
            "catalog": catalog,
            "meta": {
                "catalog_size": len(catalog),
                "parser": parser,
            },
        }

        if llm_attempted:
            # BUG FIX: record the LLM as used whenever it was invoked, even if it
            # returned empty text (previously only non-empty output was flagged).
            result["meta"]["llm_used"] = True
        if llm_error:
            result["meta"]["llm_error"] = llm_error

        # always return valid json text payload for downstream json.loads
        yield self.create_text_message(json.dumps(result, ensure_ascii=False))
        yield self.create_json_message(result)

def _parse_with_llm(
|
||
self,
|
||
toc_start: int,
|
||
toc_end: int,
|
||
toc_pages: str,
|
||
model_config: dict[str, Any],
|
||
) -> tuple[dict[str, Any] | None, str, str | None]:
|
||
user_content = (
|
||
f"TOC page index range: {toc_start}..{toc_end}\n\n"
|
||
f"TOC raw text:\n{toc_pages}"
|
||
)
|
||
response = self.session.model.llm.invoke(
|
||
model_config=LLMModelConfig(**model_config),
|
||
prompt_messages=[
|
||
SystemPromptMessage(content=_SYSTEM_PROMPT),
|
||
UserPromptMessage(content=user_content),
|
||
],
|
||
stream=False,
|
||
)
|
||
|
||
llm_text = ""
|
||
if hasattr(response, "message") and response.message:
|
||
content = response.message.content
|
||
if isinstance(content, str):
|
||
llm_text = content
|
||
elif isinstance(content, list):
|
||
llm_text = "".join(
|
||
item.data if hasattr(item, "data") else str(item) for item in content
|
||
)
|
||
|
||
parsed = self._extract_json_object(llm_text)
|
||
if parsed is None:
|
||
return None, llm_text, "Failed to parse LLM output as JSON"
|
||
if not isinstance(parsed, dict):
|
||
return None, llm_text, "LLM output JSON is not an object"
|
||
|
||
return parsed, llm_text, None
|
||
|
||
@staticmethod
|
||
def _strip_index_lists(text: str) -> str:
|
||
# Stop before common appendix lists that pollute TOC parsing.
|
||
pattern = re.compile(
|
||
r"^(List\s+of\s+Figures|List\s+of\s+Tables|图目录|表目录)",
|
||
re.IGNORECASE | re.MULTILINE,
|
||
)
|
||
m = pattern.search(text)
|
||
return text[: m.start()].rstrip() if m else text
|
||
|
||
def _parse_toc_lines(self, text: str) -> dict[str, dict[str, int]]:
|
||
"""Parse lines like:
|
||
1.2 Engine Overview ........ 35
|
||
Appendix A 120
|
||
"""
|
||
line_pattern = re.compile(
|
||
r"^\s*(?P<title>.+?)\s*(?:\.{2,}|\s)\s*(?P<page>\d{1,5})\s*$"
|
||
)
|
||
|
||
entries: list[tuple[str, int]] = []
|
||
for raw in text.splitlines():
|
||
line = raw.strip()
|
||
if not line or len(line) < 3:
|
||
continue
|
||
if re.fullmatch(r"\d+", line):
|
||
continue
|
||
|
||
m = line_pattern.match(line)
|
||
if not m:
|
||
continue
|
||
|
||
title = re.sub(r"\s+", " ", m.group("title")).strip("-_:: ")
|
||
page = self._to_int(m.group("page"), None)
|
||
if not title or page is None:
|
||
continue
|
||
|
||
# Skip obvious noise.
|
||
if len(title) <= 1 or title.lower() in {"page", "pages", "目录", "contents"}:
|
||
continue
|
||
|
||
entries.append((title, page))
|
||
|
||
if not entries:
|
||
return {}
|
||
|
||
# Deduplicate keeping earliest appearance.
|
||
dedup: OrderedDict[str, int] = OrderedDict()
|
||
for title, page in entries:
|
||
if title not in dedup:
|
||
dedup[title] = page
|
||
|
||
titles = list(dedup.keys())
|
||
pages = [dedup[t] for t in titles]
|
||
|
||
catalog: dict[str, dict[str, int]] = {}
|
||
for i, title in enumerate(titles):
|
||
start = pages[i]
|
||
if i + 1 < len(pages):
|
||
next_start = pages[i + 1]
|
||
end = max(start, next_start - 1)
|
||
else:
|
||
end = start
|
||
catalog[title] = {"start": int(start), "end": int(end)}
|
||
|
||
return catalog
|
||
|
||
def _normalize_catalog(self, raw: dict[str, Any]) -> dict[str, dict[str, int]]:
|
||
catalog: dict[str, dict[str, int]] = {}
|
||
source = raw.get("catalog") if isinstance(raw.get("catalog"), dict) else raw
|
||
if not isinstance(source, dict):
|
||
return catalog
|
||
|
||
for name, value in source.items():
|
||
if not isinstance(name, str) or not isinstance(value, dict):
|
||
continue
|
||
start = self._to_int(value.get("start"), None)
|
||
end = self._to_int(value.get("end"), start)
|
||
if start is None:
|
||
continue
|
||
if end is None:
|
||
end = start
|
||
catalog[name] = {"start": int(start), "end": int(max(start, end))}
|
||
return catalog
|
||
|
||
@staticmethod
|
||
def _extract_json_object(text: str) -> Any:
|
||
if not text:
|
||
return None
|
||
|
||
candidates: list[str] = []
|
||
|
||
code_blocks = re.findall(r"```(?:json)?\s*([\s\S]*?)\s*```", text, flags=re.IGNORECASE)
|
||
candidates.extend([c.strip() for c in code_blocks if c.strip()])
|
||
|
||
brace_candidate = PdfTocTool._extract_first_brace_object(text)
|
||
if brace_candidate:
|
||
candidates.append(brace_candidate)
|
||
|
||
candidates.append(text.strip())
|
||
|
||
for cand in candidates:
|
||
parsed = PdfTocTool._json_try_parse(cand)
|
||
if parsed is not None:
|
||
return parsed
|
||
return None
|
||
|
||
@staticmethod
|
||
def _extract_first_brace_object(text: str) -> str | None:
|
||
start = text.find("{")
|
||
if start < 0:
|
||
return None
|
||
|
||
depth = 0
|
||
in_str = False
|
||
escape = False
|
||
for i in range(start, len(text)):
|
||
ch = text[i]
|
||
if in_str:
|
||
if escape:
|
||
escape = False
|
||
elif ch == "\\":
|
||
escape = True
|
||
elif ch == '"':
|
||
in_str = False
|
||
continue
|
||
|
||
if ch == '"':
|
||
in_str = True
|
||
elif ch == "{":
|
||
depth += 1
|
||
elif ch == "}":
|
||
depth -= 1
|
||
if depth == 0:
|
||
return text[start : i + 1]
|
||
return None
|
||
|
||
@staticmethod
|
||
def _json_try_parse(text: str) -> Any:
|
||
try:
|
||
return json.loads(text)
|
||
except Exception:
|
||
pass
|
||
|
||
# Minimal repair: remove trailing commas before } or ]
|
||
repaired = re.sub(r",\s*([}\]])", r"\1", text)
|
||
try:
|
||
return json.loads(repaired)
|
||
except Exception:
|
||
return None
|
||
|
||
@staticmethod
|
||
def _to_int(value: Any, default: int | None) -> int | None:
|
||
try:
|
||
if value is None or value == "":
|
||
return default
|
||
return int(value)
|
||
except Exception:
|
||
return default
|