初始化医疗报告生成项目,添加核心代码文件
This commit is contained in:
187
backend/test_ocr_and_classify.py
Normal file
187
backend/test_ocr_and_classify.py
Normal file
@@ -0,0 +1,187 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
实际OCR提取 + 分类测试
|
||||
流程:
|
||||
1. 调用百度OCR提取 医疗报告智能体 文件夹下的PDF
|
||||
2. 用 parse_medical_data_v2 解析OCR文本
|
||||
3. 用 classify_abb_module 对每个项目分类
|
||||
4. 输出分类结果统计
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
import io
|
||||
import json
|
||||
import time
|
||||
|
||||
# 修复 Windows 终端 UTF-8
|
||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
|
||||
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace")
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(Path(__file__).parent / ".env")
|
||||
|
||||
from parse_medical_v2 import parse_medical_data_v2, clean_extracted_data_v2
|
||||
from extract_and_fill_report import (
|
||||
extract_pdf_text, classify_abb_module, match_with_template,
|
||||
extract_patient_info
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
pdf_dir = Path(r"c:\Users\UI\Desktop\医疗报告\医疗报告智能体")
|
||||
config_path = Path(__file__).parent / "abb_mapping_config.json"
|
||||
|
||||
pdf_files = list(pdf_dir.glob("*.pdf"))
|
||||
if not pdf_files:
|
||||
print("[ERROR] 没有找到PDF文件")
|
||||
return
|
||||
|
||||
print("=" * 70)
|
||||
print(" 百度OCR提取 + 分类测试")
|
||||
print("=" * 70)
|
||||
|
||||
# ========== 步骤1: OCR提取 ==========
|
||||
all_items = []
|
||||
for pdf_file in pdf_files:
|
||||
print(f"\n📄 OCR提取: {pdf_file.name} ({pdf_file.stat().st_size / 1024:.0f} KB)")
|
||||
start = time.time()
|
||||
text = extract_pdf_text(str(pdf_file))
|
||||
elapsed = time.time() - start
|
||||
lines = [l for l in text.split('\n') if l.strip()]
|
||||
print(f" ✓ OCR完成 | 耗时: {elapsed:.1f}s | 行数: {len(lines)}")
|
||||
|
||||
# 保存OCR原文用于调试
|
||||
ocr_output = Path(__file__).parent / "test_ocr_raw_text.txt"
|
||||
with open(ocr_output, 'w', encoding='utf-8') as f:
|
||||
f.write(text)
|
||||
print(f" ✓ OCR原文已保存: {ocr_output.name}")
|
||||
|
||||
# 提取患者信息
|
||||
patient_info = extract_patient_info(text)
|
||||
print(f"\n 患者信息:")
|
||||
print(f" 姓名: {patient_info.get('name', '未提取')}")
|
||||
print(f" 性别: {patient_info.get('gender', '未提取')}")
|
||||
print(f" 年龄: {patient_info.get('age', '未提取')}")
|
||||
|
||||
# 解析检测项
|
||||
items = parse_medical_data_v2(text, pdf_file.name)
|
||||
items = clean_extracted_data_v2(items)
|
||||
print(f"\n ✓ 解析出 {len(items)} 个检测项")
|
||||
all_items.extend(items)
|
||||
|
||||
if not all_items:
|
||||
print("\n[WARN] 未提取到任何检测项,OCR结果可能为非血液检测报告")
|
||||
print(" 请检查 test_ocr_raw_text.txt 查看OCR原文")
|
||||
return
|
||||
|
||||
# ========== 步骤2: 分类测试 ==========
|
||||
print("\n" + "=" * 70)
|
||||
print(f" 分类测试 ({len(all_items)} 个检测项)")
|
||||
print("=" * 70)
|
||||
|
||||
# 按模块分组
|
||||
by_module = {}
|
||||
unclassified = []
|
||||
|
||||
for item in all_items:
|
||||
abb = item.get('abb', '')
|
||||
project = item.get('project', abb)
|
||||
result = item.get('result', '')
|
||||
module = classify_abb_module(abb, project, api_key=None)
|
||||
|
||||
item['classified_module'] = module
|
||||
|
||||
if module == 'Other':
|
||||
unclassified.append(item)
|
||||
else:
|
||||
if module not in by_module:
|
||||
by_module[module] = []
|
||||
by_module[module].append(item)
|
||||
|
||||
# 打印每个模块的项目
|
||||
print(f"\n 分类成功: {len(all_items) - len(unclassified)} 个")
|
||||
print(f" 未分类(Other): {len(unclassified)} 个")
|
||||
|
||||
# 按模块显示
|
||||
print("\n" + "-" * 70)
|
||||
for module, items in sorted(by_module.items()):
|
||||
print(f"\n 📁 [{module}] ({len(items)} 项)")
|
||||
for item in items:
|
||||
abb = item.get('abb', '?')
|
||||
project = item.get('project', '')[:30]
|
||||
result = item.get('result', '')[:15]
|
||||
point = item.get('point', '')
|
||||
print(f" {abb:<15} {project:<32} = {result:<15} {point}")
|
||||
|
||||
if unclassified:
|
||||
print(f"\n ⚠️ [Other - 未分类] ({len(unclassified)} 项)")
|
||||
for item in unclassified:
|
||||
abb = item.get('abb', '?')
|
||||
project = item.get('project', '')[:40]
|
||||
result = item.get('result', '')[:15]
|
||||
print(f" {abb:<15} {project:<42} = {result}")
|
||||
|
||||
# ========== 步骤3: 模板匹配测试 ==========
|
||||
print("\n" + "=" * 70)
|
||||
print(" 模板匹配测试")
|
||||
print("=" * 70)
|
||||
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
|
||||
matched = match_with_template(all_items, config)
|
||||
print(f" 模板匹配: {len(matched)} 个项目")
|
||||
|
||||
# 统计
|
||||
module_count = {}
|
||||
for abb, data in matched.items():
|
||||
module = data.get('module', '')
|
||||
if not module:
|
||||
module = classify_abb_module(abb, data.get('project', abb), api_key=None)
|
||||
if module not in module_count:
|
||||
module_count[module] = 0
|
||||
module_count[module] += 1
|
||||
|
||||
print("\n 模块分布:")
|
||||
for module, count in sorted(module_count.items(), key=lambda x: -x[1]):
|
||||
print(f" {module:<30} {count} 项")
|
||||
|
||||
# ========== 汇总 ==========
|
||||
print("\n" + "=" * 70)
|
||||
print(" 汇总")
|
||||
print("=" * 70)
|
||||
total = len(all_items)
|
||||
classified = total - len(unclassified)
|
||||
rate = classified / total * 100 if total else 0
|
||||
print(f" 总提取项: {total}")
|
||||
print(f" 分类成功: {classified} ({rate:.1f}%)")
|
||||
print(f" 未分类: {len(unclassified)}")
|
||||
print(f" 模块数: {len(by_module)}")
|
||||
print("=" * 70)
|
||||
|
||||
# 保存结果
|
||||
output_path = Path(__file__).parent / "test_ocr_classify_result.json"
|
||||
save_data = {
|
||||
"total_items": total,
|
||||
"classified": classified,
|
||||
"unclassified_count": len(unclassified),
|
||||
"modules": {m: len(items) for m, items in by_module.items()},
|
||||
"items": [{
|
||||
"abb": item.get("abb", ""),
|
||||
"project": item.get("project", ""),
|
||||
"result": item.get("result", ""),
|
||||
"point": item.get("point", ""),
|
||||
"unit": item.get("unit", ""),
|
||||
"module": item.get("classified_module", "Other")
|
||||
} for item in all_items]
|
||||
}
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(save_data, f, ensure_ascii=False, indent=2)
|
||||
print(f"\n 结果已保存: {output_path.name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user