Files
yiliao/backend/test_ocr_and_classify.py

188 lines
6.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
End-to-end OCR extraction + classification test.

Pipeline:
1. Run Baidu OCR (through ``extract_pdf_text``) on the PDFs in the
   "医疗报告智能体" folder.
2. Parse the OCR text with ``parse_medical_data_v2``.
3. Classify every parsed item with ``classify_abb_module``.
4. Print classification statistics.
"""
import sys
import os
import io
import json
import time

# Force UTF-8 console output so the Chinese text / emoji printed by this script
# survive on Windows terminals. ``reconfigure`` changes the encoding in place,
# so each stream keeps its original buffering mode (line-buffered on a
# terminal) and progress lines show up immediately. Wrapping ``.buffer`` in a
# new TextIOWrapper would make them block-buffered and would crash when the
# stream is None or has no ``.buffer``.
for _stream in (sys.stdout, sys.stderr):
    if hasattr(_stream, "reconfigure"):
        _stream.reconfigure(encoding="utf-8", errors="replace")
del _stream

# Make the sibling modules imported below resolvable regardless of the
# current working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from pathlib import Path
from dotenv import load_dotenv

# Load .env before importing the project modules below, in case they read
# settings at import time (presumably the OCR API credentials -- confirm).
load_dotenv(Path(__file__).parent / ".env")

from parse_medical_v2 import parse_medical_data_v2, clean_extracted_data_v2
from extract_and_fill_report import (
    extract_pdf_text, classify_abb_module, match_with_template,
    extract_patient_info
)
from collections import Counter, defaultdict

# Folder scanned when main() is called without ``pdf_dir``.
_DEFAULT_PDF_DIR = Path(r"c:\Users\UI\Desktop\医疗报告\医疗报告智能体")


def _find_pdf_files(pdf_dir):
    """Return the regular ``.pdf`` files directly inside *pdf_dir*, sorted by path.

    The suffix check is case-insensitive so ``.PDF`` files are also found on
    case-sensitive file systems; directories are skipped and a missing
    directory yields an empty list.
    """
    pdf_dir = Path(pdf_dir)
    if not pdf_dir.is_dir():
        return []
    return sorted(
        path for path in pdf_dir.iterdir()
        if path.is_file() and path.suffix.lower() == ".pdf"
    )


def _as_text(value):
    """Format an item field for console output (``None`` -> ``''``, else ``str``)."""
    return "" if value is None else str(value)


def _ocr_and_parse(pdf_files, raw_text_path):
    """OCR and parse every PDF; return all cleaned test items in file order.

    The raw OCR text is saved to *raw_text_path* for debugging. With several
    PDFs each text is preceded by a ``===== <file name> =====`` header, so one
    file's text no longer overwrites another's; a single PDF is saved as-is.
    """
    all_items = []
    with_headers = len(pdf_files) > 1
    with open(raw_text_path, 'w', encoding='utf-8') as raw_file:
        for pdf_file in pdf_files:
            print(f"\n📄 OCR提取: {pdf_file.name} ({pdf_file.stat().st_size / 1024:.0f} KB)")
            start = time.perf_counter()
            text = extract_pdf_text(str(pdf_file))
            elapsed = time.perf_counter() - start
            line_count = sum(1 for line in text.split('\n') if line.strip())
            print(f" ✓ OCR完成 | 耗时: {elapsed:.1f}s | 行数: {line_count}")

            # Keep the raw OCR text for debugging the parser.
            if with_headers:
                raw_file.write(f"===== {pdf_file.name} =====\n")
            raw_file.write(text)
            if with_headers:
                raw_file.write("\n\n")
            # Flush so the file can be inspected while later PDFs are still being OCR'd.
            raw_file.flush()
            print(f" ✓ OCR原文已保存: {raw_text_path.name}")

            patient_info = extract_patient_info(text)
            print("\n 患者信息:")
            print(f" 姓名: {patient_info.get('name', '未提取')}")
            print(f" 性别: {patient_info.get('gender', '未提取')}")
            print(f" 年龄: {patient_info.get('age', '未提取')}")

            items = clean_extracted_data_v2(parse_medical_data_v2(text, pdf_file.name))
            print(f"\n ✓ 解析出 {len(items)} 个检测项")
            all_items.extend(items)
    return all_items


def _classify_items(all_items):
    """Classify every item in place (``classified_module``) and group by module.

    Returns ``(by_module, unclassified)``: a dict mapping module name to its
    items (in first-seen order) and the list of items classified as 'Other'.
    """
    by_module = defaultdict(list)
    unclassified = []
    for item in all_items:
        abb = item.get('abb', '')
        project = item.get('project', abb)
        module = classify_abb_module(abb, project, api_key=None)
        item['classified_module'] = module
        if module == 'Other':
            unclassified.append(item)
        else:
            by_module[module].append(item)
    return dict(by_module), unclassified


def _print_classification(all_items, by_module, unclassified):
    """Print the classification counts, then the items of each module and of 'Other'."""
    print(f"\n 分类成功: {len(all_items) - len(unclassified)}")
    print(f" 未分类(Other): {len(unclassified)}")
    print("\n" + "-" * 70)
    for module, items in sorted(by_module.items()):
        print(f"\n 📁 [{module}] ({len(items)} 项)")
        for item in items:
            abb = _as_text(item.get('abb', '?'))
            project = _as_text(item.get('project', ''))[:30]
            result = _as_text(item.get('result', ''))[:15]
            point = _as_text(item.get('point', ''))
            print(f" {abb:<15} {project:<32} = {result:<15} {point}")
    if unclassified:
        print(f"\n ⚠️ [Other - 未分类] ({len(unclassified)} 项)")
        for item in unclassified:
            abb = _as_text(item.get('abb', '?'))
            project = _as_text(item.get('project', ''))[:40]
            result = _as_text(item.get('result', ''))[:15]
            print(f" {abb:<15} {project:<42} = {result}")


def _count_matched_modules(matched):
    """Count template-matched entries per module.

    Entries whose ``module`` is missing or empty are classified on the fly.
    """
    counts = Counter()
    for abb, data in matched.items():
        module = data.get('module', '') or classify_abb_module(
            abb, data.get('project', abb), api_key=None)
        counts[module] += 1
    return counts


def _save_results(output_path, all_items, by_module, unclassified):
    """Write the classification summary and per-item results to *output_path* as JSON."""
    total = len(all_items)
    save_data = {
        "total_items": total,
        "classified": total - len(unclassified),
        "unclassified_count": len(unclassified),
        "modules": {module: len(items) for module, items in by_module.items()},
        "items": [
            {
                "abb": item.get("abb", ""),
                "project": item.get("project", ""),
                "result": item.get("result", ""),
                "point": item.get("point", ""),
                "unit": item.get("unit", ""),
                "module": item.get("classified_module", "Other"),
            }
            for item in all_items
        ],
    }
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(save_data, f, ensure_ascii=False, indent=2)


def main(pdf_dir=None, config_path=None):
    """Run OCR -> parse -> classify -> template-match over a folder of PDFs.

    Args:
        pdf_dir: folder whose ``.pdf`` files are processed; defaults to the
            hard-coded local test folder.
        config_path: ABB mapping config (JSON) passed to
            ``match_with_template``; defaults to ``abb_mapping_config.json``
            next to this script. If it does not exist, template matching is
            skipped with a warning instead of crashing after the OCR work.

    Writes ``test_ocr_raw_text.txt`` and ``test_ocr_classify_result.json``
    next to this script. Returns ``None``; returns early (after printing a
    message) when no PDFs or no test items are found.
    """
    script_dir = Path(__file__).parent
    pdf_dir = _DEFAULT_PDF_DIR if pdf_dir is None else Path(pdf_dir)
    if config_path is None:
        config_path = script_dir / "abb_mapping_config.json"
    else:
        config_path = Path(config_path)

    pdf_files = _find_pdf_files(pdf_dir)
    if not pdf_files:
        print("[ERROR] 没有找到PDF文件")
        return
    print("=" * 70)
    print(" 百度OCR提取 + 分类测试")
    print("=" * 70)

    # ========== Step 1: OCR extraction ==========
    all_items = _ocr_and_parse(pdf_files, script_dir / "test_ocr_raw_text.txt")
    if not all_items:
        print("\n[WARN] 未提取到任何检测项，OCR结果可能为非血液检测报告")
        print(" 请检查 test_ocr_raw_text.txt 查看OCR原文")
        return

    # ========== Step 2: classification ==========
    print("\n" + "=" * 70)
    print(f" 分类测试 ({len(all_items)} 个检测项)")
    print("=" * 70)
    by_module, unclassified = _classify_items(all_items)
    _print_classification(all_items, by_module, unclassified)

    # ========== Step 3: template matching ==========
    print("\n" + "=" * 70)
    print(" 模板匹配测试")
    print("=" * 70)
    if config_path.is_file():
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        matched = match_with_template(all_items, config)
        print(f" 模板匹配: {len(matched)} 个项目")
        print("\n 模块分布:")
        # most_common() orders by count descending, ties in first-seen order.
        for module, count in _count_matched_modules(matched).most_common():
            print(f" {module!s:<30} {count}")
    else:
        print(f" [WARN] 未找到配置文件 {config_path}，跳过模板匹配")

    # ========== Summary ==========
    print("\n" + "=" * 70)
    print(" 汇总")
    print("=" * 70)
    total = len(all_items)
    classified = total - len(unclassified)
    rate = classified / total * 100 if total else 0
    print(f" 总提取项: {total}")
    print(f" 分类成功: {classified} ({rate:.1f}%)")
    print(f" 未分类: {len(unclassified)}")
    print(f" 模块数: {len(by_module)}")
    print("=" * 70)

    output_path = script_dir / "test_ocr_classify_result.json"
    _save_results(output_path, all_items, by_module, unclassified)
    print(f"\n 结果已保存: {output_path.name}")


if __name__ == "__main__":
    main()