188 lines
6.3 KiB
Python
188 lines
6.3 KiB
Python
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
"""
|
|||
|
|
实际OCR提取 + 分类测试
|
|||
|
|
流程:
|
|||
|
|
1. 调用百度OCR提取 医疗报告智能体 文件夹下的PDF
|
|||
|
|
2. 用 parse_medical_data_v2 解析OCR文本
|
|||
|
|
3. 用 classify_abb_module 对每个项目分类
|
|||
|
|
4. 输出分类结果统计
|
|||
|
|
"""
|
|||
|
|
import sys
|
|||
|
|
import os
|
|||
|
|
import io
|
|||
|
|
import json
|
|||
|
|
import time
|
|||
|
|
|
|||
|
|
# Force UTF-8 output so Chinese text and emoji print on Windows terminals.
def ensure_utf8_stream(stream):
    """Return *stream* switched to UTF-8 with ``errors="replace"``.

    Streams that support ``reconfigure`` (``io.TextIOWrapper`` on
    Python 3.7+) are changed in place, which keeps their current
    line-buffering setting, so progress messages still appear as soon as
    they are printed.  Older text streams that only expose ``.buffer`` are
    re-wrapped with line buffering enabled.  Anything else (``None``, an
    ``io.StringIO`` redirect, ...) is returned unchanged instead of raising
    ``AttributeError``.
    """
    if hasattr(stream, "reconfigure"):
        stream.reconfigure(encoding="utf-8", errors="replace")
        return stream
    if hasattr(stream, "buffer"):
        return io.TextIOWrapper(
            stream.buffer,
            encoding="utf-8",
            errors="replace",
            line_buffering=True,
        )
    return stream


sys.stdout = ensure_utf8_stream(sys.stdout)
sys.stderr = ensure_utf8_stream(sys.stderr)
|
|||
|
|
|
|||
|
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|||
|
|
|
|||
|
|
from pathlib import Path
|
|||
|
|
from dotenv import load_dotenv
|
|||
|
|
load_dotenv(Path(__file__).parent / ".env")
|
|||
|
|
|
|||
|
|
from parse_medical_v2 import parse_medical_data_v2, clean_extracted_data_v2
|
|||
|
|
from extract_and_fill_report import (
|
|||
|
|
extract_pdf_text, classify_abb_module, match_with_template,
|
|||
|
|
extract_patient_info
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Default folder scanned for PDFs when main() is called without arguments.
_DEFAULT_PDF_DIR = r"c:\Users\UI\Desktop\医疗报告\医疗报告智能体"


def _print_banner(title, blank_before=True):
    """Print *title* framed by two 70-character ``=`` rules."""
    rule = "=" * 70
    print("\n" + rule if blank_before else rule)
    print(title)
    print(rule)


def _ocr_and_parse_pdf(pdf_file, ocr_dump_path, append_dump):
    """OCR one PDF, dump its raw text, print patient info, return parsed items.

    Args:
        pdf_file: ``Path`` of the PDF to process.
        ocr_dump_path: ``Path`` of the debug file receiving the raw OCR text.
        append_dump: ``False`` for the first PDF (the dump file is truncated),
            ``True`` afterwards (the text is appended after a separator line
            naming the PDF) so that earlier documents are not overwritten.

    Returns:
        The items produced by ``parse_medical_data_v2`` after
        ``clean_extracted_data_v2``.
    """
    print(f"\n📄 OCR提取: {pdf_file.name} ({pdf_file.stat().st_size / 1024:.0f} KB)")
    start = time.time()
    text = extract_pdf_text(str(pdf_file))
    elapsed = time.time() - start
    non_blank = [line for line in text.split('\n') if line.strip()]
    print(f" ✓ OCR完成 | 耗时: {elapsed:.1f}s | 行数: {len(non_blank)}")

    # Keep the raw OCR text for debugging.  With a single PDF the file content
    # is exactly the OCR text; further PDFs are appended, not overwritten.
    with open(ocr_dump_path, 'a' if append_dump else 'w', encoding='utf-8') as f:
        if append_dump:
            f.write(f"\n\n===== {pdf_file.name} =====\n")
        f.write(text)
    print(f" ✓ OCR原文已保存: {ocr_dump_path.name}")

    # Patient information
    patient_info = extract_patient_info(text)
    print("\n 患者信息:")
    print(f" 姓名: {patient_info.get('name', '未提取')}")
    print(f" 性别: {patient_info.get('gender', '未提取')}")
    print(f" 年龄: {patient_info.get('age', '未提取')}")

    # Parsed test items
    items = clean_extracted_data_v2(parse_medical_data_v2(text, pdf_file.name))
    print(f"\n ✓ 解析出 {len(items)} 个检测项")
    return items


def _classify_items(items):
    """Tag every item with ``classified_module`` and group the items.

    Returns:
        ``(by_module, unclassified)``: a dict mapping each module name other
        than ``'Other'`` to its items, and the list of items classified as
        ``'Other'``.
    """
    by_module = {}
    unclassified = []
    for item in items:
        abb = item.get('abb', '')
        module = classify_abb_module(abb, item.get('project', abb), api_key=None)
        item['classified_module'] = module
        if module == 'Other':
            unclassified.append(item)
        else:
            by_module.setdefault(module, []).append(item)
    return by_module, unclassified


def _count_template_modules(matched):
    """Count template-matched entries per module.

    Entries without a ``module`` value are classified with
    ``classify_abb_module`` first.
    """
    counts = {}
    for abb, data in matched.items():
        module = data.get('module', '') or classify_abb_module(
            abb, data.get('project', abb), api_key=None
        )
        counts[module] = counts.get(module, 0) + 1
    return counts


def main(pdf_dir=None):
    """Run OCR on every PDF in *pdf_dir*, classify the items, report and save.

    Steps: OCR + parse each PDF, classify every parsed item, run the template
    matching against ``abb_mapping_config.json``, print a summary and write
    ``test_ocr_classify_result.json`` next to this script.

    Args:
        pdf_dir: Folder to scan for ``*.pdf`` files (``str`` or ``Path``).
            Defaults to the original hard-coded report folder.

    Returns:
        None.  Returns early when no PDF is found or no item is extracted.
    """
    pdf_dir = Path(_DEFAULT_PDF_DIR if pdf_dir is None else pdf_dir)
    script_dir = Path(__file__).parent
    config_path = script_dir / "abb_mapping_config.json"

    # Sorted so that processing order does not depend on the file system.
    pdf_files = sorted(pdf_dir.glob("*.pdf"))
    if not pdf_files:
        print("[ERROR] 没有找到PDF文件")
        return

    _print_banner(" 百度OCR提取 + 分类测试", blank_before=False)

    # ========== Step 1: OCR extraction ==========
    ocr_dump_path = script_dir / "test_ocr_raw_text.txt"
    all_items = []
    for index, pdf_file in enumerate(pdf_files):
        all_items.extend(
            _ocr_and_parse_pdf(pdf_file, ocr_dump_path, append_dump=index > 0)
        )

    if not all_items:
        print("\n[WARN] 未提取到任何检测项,OCR结果可能为非血液检测报告")
        print(" 请检查 test_ocr_raw_text.txt 查看OCR原文")
        return

    # ========== Step 2: classification test ==========
    _print_banner(f" 分类测试 ({len(all_items)} 个检测项)")
    by_module, unclassified = _classify_items(all_items)

    print(f"\n 分类成功: {len(all_items) - len(unclassified)} 个")
    print(f" 未分类(Other): {len(unclassified)} 个")

    # Items listed module by module
    print("\n" + "-" * 70)
    for module, items in sorted(by_module.items()):
        print(f"\n 📁 [{module}] ({len(items)} 项)")
        for item in items:
            abb = item.get('abb', '?')
            project = item.get('project', '')[:30]
            result = item.get('result', '')[:15]
            point = item.get('point', '')
            print(f" {abb:<15} {project:<32} = {result:<15} {point}")

    if unclassified:
        print(f"\n ⚠️ [Other - 未分类] ({len(unclassified)} 项)")
        for item in unclassified:
            abb = item.get('abb', '?')
            project = item.get('project', '')[:40]
            result = item.get('result', '')[:15]
            print(f" {abb:<15} {project:<42} = {result}")

    # ========== Step 3: template matching test ==========
    _print_banner(" 模板匹配测试")

    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    matched = match_with_template(all_items, config)
    print(f" 模板匹配: {len(matched)} 个项目")

    module_count = _count_template_modules(matched)
    print("\n 模块分布:")
    # Largest modules first; ties keep their first-seen order.
    for module, count in sorted(module_count.items(), key=lambda x: -x[1]):
        print(f" {module:<30} {count} 项")

    # ========== Summary ==========
    _print_banner(" 汇总")
    total = len(all_items)
    classified = total - len(unclassified)
    rate = classified / total * 100 if total else 0
    print(f" 总提取项: {total}")
    print(f" 分类成功: {classified} ({rate:.1f}%)")
    print(f" 未分类: {len(unclassified)}")
    print(f" 模块数: {len(by_module)}")
    print("=" * 70)

    # Save the results
    output_path = script_dir / "test_ocr_classify_result.json"
    save_data = {
        "total_items": total,
        "classified": classified,
        "unclassified_count": len(unclassified),
        "modules": {m: len(items) for m, items in by_module.items()},
        "items": [{
            "abb": item.get("abb", ""),
            "project": item.get("project", ""),
            "result": item.get("result", ""),
            "point": item.get("point", ""),
            "unit": item.get("unit", ""),
            "module": item.get("classified_module", "Other"),
        } for item in all_items],
    }
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(save_data, f, ensure_ascii=False, indent=2)
    print(f"\n 结果已保存: {output_path.name}")


if __name__ == "__main__":
    main()
|