188 lines
6.3 KiB
Python
188 lines
6.3 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
实际OCR提取 + 分类测试
|
||
流程:
|
||
1. 调用百度OCR提取 医疗报告智能体 文件夹下的PDF
|
||
2. 用 parse_medical_data_v2 解析OCR文本
|
||
3. 用 classify_abb_module 对每个项目分类
|
||
4. 输出分类结果统计
|
||
"""
|
||
import sys
|
||
import os
|
||
import io
|
||
import json
|
||
import time
|
||
|
||
# Force UTF-8 (with replacement of unencodable characters) on stdout/stderr so
# the Chinese text and emoji printed by this script work on Windows consoles.
# reconfigure() changes the existing stream in place and keeps its other
# settings (such as line buffering); building a fresh TextIOWrapper would
# reset them. The wrapper is kept only as a fallback for stream objects that
# expose a binary ``buffer`` but no ``reconfigure``; anything else (including
# ``None``) is left untouched instead of raising AttributeError.
for _stream_name in ("stdout", "stderr"):
    _stream = getattr(sys, _stream_name)
    if hasattr(_stream, "reconfigure"):
        _stream.reconfigure(encoding="utf-8", errors="replace")
    elif hasattr(_stream, "buffer"):
        setattr(
            sys,
            _stream_name,
            io.TextIOWrapper(_stream.buffer, encoding="utf-8", errors="replace"),
        )
|
||
|
||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||
|
||
from pathlib import Path
|
||
from dotenv import load_dotenv
|
||
load_dotenv(Path(__file__).parent / ".env")
|
||
|
||
from parse_medical_v2 import parse_medical_data_v2, clean_extracted_data_v2
|
||
from extract_and_fill_report import (
|
||
extract_pdf_text, classify_abb_module, match_with_template,
|
||
extract_patient_info
|
||
)
|
||
|
||
|
||
def _text(value):
    """Return ``value`` as a string, mapping ``None`` to an empty string."""
    return "" if value is None else str(value)


def _extract_items(pdf_files, ocr_dump_path):
    """Run OCR on every PDF, print a per-file report and return all parsed items.

    Args:
        pdf_files: Sequence of ``Path`` objects pointing at the PDFs to process.
        ocr_dump_path: File that receives the raw OCR text for debugging.

    Returns:
        A single list with the cleaned items parsed from all PDFs.
    """
    all_items = []
    several_pdfs = len(pdf_files) > 1

    for index, pdf_file in enumerate(pdf_files):
        print(f"\n📄 OCR提取: {pdf_file.name} ({pdf_file.stat().st_size / 1024:.0f} KB)")
        # perf_counter is monotonic, so the reported duration cannot be
        # distorted by system clock adjustments.
        start = time.perf_counter()
        text = extract_pdf_text(str(pdf_file))
        elapsed = time.perf_counter() - start
        line_count = sum(1 for line in text.split('\n') if line.strip())
        print(f" ✓ OCR完成 | 耗时: {elapsed:.1f}s | 行数: {line_count}")

        # Keep the raw OCR text for debugging. The file is truncated for the
        # first PDF only and appended to afterwards, so the text of earlier
        # PDFs is no longer overwritten by later ones. A header line is added
        # only when several PDFs share the dump; a single-PDF dump contains
        # exactly the OCR text.
        with open(ocr_dump_path, 'w' if index == 0 else 'a', encoding='utf-8') as f:
            if several_pdfs:
                f.write(f"\n===== {pdf_file.name} =====\n")
            f.write(text)
        print(f" ✓ OCR原文已保存: {ocr_dump_path.name}")

        # Patient information
        patient_info = extract_patient_info(text)
        print("\n 患者信息:")
        print(f" 姓名: {patient_info.get('name', '未提取')}")
        print(f" 性别: {patient_info.get('gender', '未提取')}")
        print(f" 年龄: {patient_info.get('age', '未提取')}")

        # Parse and clean the test items of this PDF
        items = clean_extracted_data_v2(parse_medical_data_v2(text, pdf_file.name))
        print(f"\n ✓ 解析出 {len(items)} 个检测项")
        all_items.extend(items)

    return all_items


def _group_by_module(all_items):
    """Classify every item and split the items into per-module groups.

    Each item gets its module stored under the ``'classified_module'`` key.
    Items whose module is ``'Other'`` are collected separately.

    Returns:
        ``(by_module, unclassified)``: a dict mapping module name to its list
        of items, and the list of items classified as ``'Other'``.
    """
    by_module = {}
    unclassified = []
    for item in all_items:
        abb = item.get('abb', '')
        module = classify_abb_module(abb, item.get('project', abb), api_key=None)
        item['classified_module'] = module
        if module == 'Other':
            unclassified.append(item)
        else:
            by_module.setdefault(module, []).append(item)
    return by_module, unclassified


def _print_classification(by_module, unclassified, total):
    """Print the classification counts followed by the items of every group."""
    print(f"\n 分类成功: {total - len(unclassified)} 个")
    print(f" 未分类(Other): {len(unclassified)} 个")

    print("\n" + "-" * 70)
    for module, items in sorted(by_module.items()):
        print(f"\n 📁 [{module}] ({len(items)} 项)")
        for item in items:
            # _text() keeps the slicing / width formatting from raising
            # TypeError when a field is None or not a string.
            abb = _text(item.get('abb', '?'))
            project = _text(item.get('project', ''))[:30]
            result = _text(item.get('result', ''))[:15]
            point = item.get('point', '')
            print(f" {abb:<15} {project:<32} = {result:<15} {point}")

    if unclassified:
        print(f"\n ⚠️ [Other - 未分类] ({len(unclassified)} 项)")
        for item in unclassified:
            abb = _text(item.get('abb', '?'))
            project = _text(item.get('project', ''))[:40]
            result = _text(item.get('result', ''))[:15]
            print(f" {abb:<15} {project:<42} = {result}")


def _count_template_modules(matched):
    """Count the template-matched entries per module.

    An entry without a truthy ``'module'`` value is classified through
    ``classify_abb_module`` instead.
    """
    module_count = {}
    for abb, data in matched.items():
        module = data.get('module', '')
        if not module:
            module = classify_abb_module(abb, data.get('project', abb), api_key=None)
        module_count[module] = module_count.get(module, 0) + 1
    return module_count


def _save_results(output_path, all_items, by_module, unclassified):
    """Write the classification summary and the per-item results as JSON."""
    total = len(all_items)
    save_data = {
        "total_items": total,
        "classified": total - len(unclassified),
        "unclassified_count": len(unclassified),
        "modules": {module: len(items) for module, items in by_module.items()},
        "items": [
            {
                "abb": item.get("abb", ""),
                "project": item.get("project", ""),
                "result": item.get("result", ""),
                "point": item.get("point", ""),
                "unit": item.get("unit", ""),
                "module": item.get("classified_module", "Other"),
            }
            for item in all_items
        ],
    }
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(save_data, f, ensure_ascii=False, indent=2)


def main(pdf_dir=None, config_path=None):
    """Run the OCR extraction + classification test and print a report.

    Steps: OCR every ``*.pdf`` in ``pdf_dir``, parse the test items, classify
    each item into a module, match the items against the template config and
    finally print a summary and save the results as JSON.

    Side effects: writes ``test_ocr_raw_text.txt`` (raw OCR text) and
    ``test_ocr_classify_result.json`` (results) next to this script.

    Args:
        pdf_dir: Directory scanned for PDFs. Defaults to the original
            hard-coded report folder when omitted.
        config_path: Path of the JSON mapping config passed to
            ``match_with_template``. Defaults to ``abb_mapping_config.json``
            next to this script.

    Returns:
        None. Returns early when no PDF is found or no item is extracted.
    """
    script_dir = Path(__file__).parent
    if pdf_dir is None:
        pdf_dir = Path(r"c:\Users\UI\Desktop\医疗报告\医疗报告智能体")
    else:
        pdf_dir = Path(pdf_dir)
    if config_path is None:
        config_path = script_dir / "abb_mapping_config.json"
    else:
        config_path = Path(config_path)

    # glob() yields entries in an unspecified order; sorting makes the
    # processing order, and therefore the report, reproducible.
    pdf_files = sorted(pdf_dir.glob("*.pdf"))
    if not pdf_files:
        print("[ERROR] 没有找到PDF文件")
        return

    print("=" * 70)
    print(" 百度OCR提取 + 分类测试")
    print("=" * 70)

    # ========== Step 1: OCR extraction ==========
    all_items = _extract_items(pdf_files, script_dir / "test_ocr_raw_text.txt")

    if not all_items:
        print("\n[WARN] 未提取到任何检测项,OCR结果可能为非血液检测报告")
        print(" 请检查 test_ocr_raw_text.txt 查看OCR原文")
        return

    # ========== Step 2: classification test ==========
    print("\n" + "=" * 70)
    print(f" 分类测试 ({len(all_items)} 个检测项)")
    print("=" * 70)

    by_module, unclassified = _group_by_module(all_items)
    _print_classification(by_module, unclassified, len(all_items))

    # ========== Step 3: template matching test ==========
    print("\n" + "=" * 70)
    print(" 模板匹配测试")
    print("=" * 70)

    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    matched = match_with_template(all_items, config)
    print(f" 模板匹配: {len(matched)} 个项目")

    module_count = _count_template_modules(matched)
    print("\n 模块分布:")
    # Most frequent module first; the stable sort keeps ties in first-seen order.
    for module, count in sorted(module_count.items(), key=lambda pair: -pair[1]):
        print(f" {module:<30} {count} 项")

    # ========== Summary ==========
    print("\n" + "=" * 70)
    print(" 汇总")
    print("=" * 70)
    total = len(all_items)
    classified = total - len(unclassified)
    rate = classified / total * 100 if total else 0
    print(f" 总提取项: {total}")
    print(f" 分类成功: {classified} ({rate:.1f}%)")
    print(f" 未分类: {len(unclassified)}")
    print(f" 模块数: {len(by_module)}")
    print("=" * 70)

    # Save the results
    output_path = script_dir / "test_ocr_classify_result.json"
    _save_results(output_path, all_items, by_module, unclassified)
    print(f"\n 结果已保存: {output_path.name}")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|