初始化医疗报告生成项目,添加核心代码文件
This commit is contained in:
370
backend/test_baidu_ocr.py
Normal file
370
backend/test_baidu_ocr.py
Normal file
@@ -0,0 +1,370 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
百度OCR识别测试脚本
|
||||
测试目标:使用百度OCR对产品报价表图片进行识别,验证识别效果
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import io
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Force UTF-8 output so Chinese text prints correctly on Windows consoles.
# reconfigure() changes the encoding in place and keeps the stream's current
# buffering mode; building a fresh TextIOWrapper instead would default to
# line_buffering=False (delaying progress messages) and would raise
# AttributeError for streams that expose no ``.buffer``.
for _stream_name in ("stdout", "stderr"):
    _stream = getattr(sys, _stream_name)
    if hasattr(_stream, "reconfigure"):
        _stream.reconfigure(encoding="utf-8", errors="replace")
    elif hasattr(_stream, "buffer"):
        # Fallback for stream objects without reconfigure(): wrap the raw
        # buffer, but keep line buffering so output still appears promptly.
        setattr(
            sys,
            _stream_name,
            io.TextIOWrapper(
                _stream.buffer,
                encoding="utf-8",
                errors="replace",
                line_buffering=True,
            ),
        )
|
||||
|
||||
# 加载环境变量
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(Path(__file__).parent / ".env")
|
||||
|
||||
from aip import AipOcr
|
||||
|
||||
|
||||
# ==================== Configuration ====================

# Baidu OCR credentials, read from environment variables; each falls back to
# an empty string when unset.
APP_ID = os.getenv("BAIDU_OCR_APP_ID", "")
API_KEY = os.getenv("BAIDU_OCR_API_KEY", "")
SECRET_KEY = os.getenv("BAIDU_OCR_SECRET_KEY", "")

# Path of the image to recognise. The built-in default only exists on the
# original author's machine, so BAIDU_OCR_TEST_IMAGE may be set to point the
# script at any other image without editing the source.
IMAGE_PATH = os.getenv("BAIDU_OCR_TEST_IMAGE") or r"C:\Users\UI\.cursor\projects\c-Users-UI-Desktop\assets\c__Users_UI_AppData_Roaming_Cursor_User_workspaceStorage_6df83b93d4a0651428307542725e79d8_images_ecdbe509-3f63-49c0-a8be-db9facaef857_3_-4dec6c0d-a755-4bda-8780-9e6b20e02df8.png"
|
||||
|
||||
|
||||
def test_accurate_basic(client, image_data):
    """Run test 1: call ``client.basicAccurate`` and print the recognised lines.

    Args:
        client: Object providing a ``basicAccurate(image_data)`` method that
            returns a dict response.
        image_data: Image content, passed to the client unchanged.

    Returns:
        The response dict on success, or ``None`` when the response contains
        an ``error_code`` key.
    """
    print("\n" + "=" * 70)
    print("[测试1] 通用文字识别(高精度版)- basicAccurate")
    print("=" * 70)

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # wall-clock adjustments.
    started = time.perf_counter()
    result = client.basicAccurate(image_data)
    elapsed = time.perf_counter() - started

    if "error_code" in result:
        print(f" [FAIL] 错误 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
        return None

    # Treat a missing or null "words_result" as "no lines recognised".
    words_result = result.get("words_result") or []
    print(f" [OK] 识别成功 | 耗时: {elapsed:.2f}s | 识别行数: {len(words_result)}")
    print("-" * 70)
    for line_no, item in enumerate(words_result, start=1):
        print(f" [{line_no:3d}] {item.get('words', '')}")

    return result


# The ``test_`` prefix would make pytest collect this helper and fail on the
# unknown ``client`` / ``image_data`` fixtures; opt out of collection.
test_accurate_basic.__test__ = False
|
||||
|
||||
|
||||
def test_accurate(client, image_data):
    """Run test 2: call ``client.accurate`` and print each line with its box.

    Args:
        client: Object providing an ``accurate(image_data)`` method that
            returns a dict response.
        image_data: Image content, passed to the client unchanged.

    Returns:
        The response dict on success, or ``None`` when the response contains
        an ``error_code`` key.
    """
    print("\n" + "=" * 70)
    print("[测试2] 通用文字识别(高精度含位置版)- accurate")
    print("=" * 70)

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # wall-clock adjustments.
    started = time.perf_counter()
    result = client.accurate(image_data)
    elapsed = time.perf_counter() - started

    if "error_code" in result:
        print(f" [FAIL] 错误 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
        return None

    # Treat a missing or null "words_result" as "no lines recognised".
    words_result = result.get("words_result") or []
    print(f" [OK] 识别成功 | 耗时: {elapsed:.2f}s | 识别行数: {len(words_result)}")
    print("-" * 70)
    for line_no, item in enumerate(words_result, start=1):
        # A missing or null "location" falls back to an all-zero box.
        loc = item.get("location") or {}
        pos_str = f"(x={loc.get('left',0)}, y={loc.get('top',0)}, w={loc.get('width',0)}, h={loc.get('height',0)})"
        print(f" [{line_no:3d}] {pos_str:40s} {item.get('words', '')}")

    return result


# The ``test_`` prefix would make pytest collect this helper and fail on the
# unknown ``client`` / ``image_data`` fixtures; opt out of collection.
test_accurate.__test__ = False
|
||||
|
||||
|
||||
def test_general_basic(client, image_data):
    """Run test 3: call ``client.basicGeneral`` and print the recognised lines.

    Args:
        client: Object providing a ``basicGeneral(image_data)`` method that
            returns a dict response.
        image_data: Image content, passed to the client unchanged.

    Returns:
        The response dict on success, or ``None`` when the response contains
        an ``error_code`` key.
    """
    print("\n" + "=" * 70)
    print("[测试3] 通用文字识别(标准版)- basicGeneral")
    print("=" * 70)

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # wall-clock adjustments.
    started = time.perf_counter()
    result = client.basicGeneral(image_data)
    elapsed = time.perf_counter() - started

    if "error_code" in result:
        print(f" [FAIL] 错误 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
        return None

    # Treat a missing or null "words_result" as "no lines recognised".
    words_result = result.get("words_result") or []
    print(f" [OK] 识别成功 | 耗时: {elapsed:.2f}s | 识别行数: {len(words_result)}")
    print("-" * 70)
    for line_no, item in enumerate(words_result, start=1):
        print(f" [{line_no:3d}] {item.get('words', '')}")

    return result


# The ``test_`` prefix would make pytest collect this helper and fail on the
# unknown ``client`` / ``image_data`` fixtures; opt out of collection.
test_general_basic.__test__ = False
|
||||
|
||||
|
||||
def test_table_recognize(client, image_data, poll_interval=3, max_attempts=20):
    """Run test 4: submit an async table-recognition job and poll for it.

    Args:
        client: Object providing ``tableRecognitionAsync(image_data)`` and
            ``getTableRecognitionResult(request_id)``, both returning dicts.
        image_data: Image content, passed to the client unchanged.
        poll_interval: Seconds to sleep before each poll (default 3).
        max_attempts: Maximum number of polls before giving up (default 20,
            i.e. about 60 seconds with the default interval).

    Returns:
        The final poll response dict once its ``ret_code`` is 3, or ``None``
        if submission fails, no ``request_id`` is returned, a poll reports an
        ``error_code``, or all attempts are used up.
    """
    print("\n" + "=" * 70)
    print("[测试4] 表格文字识别 - tableRecognitionAsync")
    print("=" * 70)

    # Submit the recognition request; timing covers submit + polling.
    started = time.perf_counter()
    result = client.tableRecognitionAsync(image_data)

    if "error_code" in result:
        print(f" [FAIL] 提交失败 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
        return None

    # The submit response shape may vary: accept either a list of dicts or a
    # single dict under "result".
    submitted = result.get("result", [])
    if isinstance(submitted, list) and submitted:
        request_id = submitted[0].get("request_id", "")
    elif isinstance(submitted, dict):
        request_id = submitted.get("request_id", "")
    else:
        request_id = ""

    if not request_id:
        print(f" [FAIL] 未获取到 request_id,返回结果: {json.dumps(result, ensure_ascii=False)}")
        return None

    print(f" [INFO] 提交成功 | request_id: {request_id}")
    print(" [INFO] 等待识别结果...")

    ret_code = -1
    for attempt in range(1, max_attempts + 1):
        time.sleep(poll_interval)
        get_result = client.getTableRecognitionResult(request_id)

        if "error_code" in get_result:
            print(f" [FAIL] 查询失败 ({get_result['error_code']}): {get_result.get('error_msg', '')}")
            return None

        # Guard against a missing/null/non-dict "result" so polling keeps
        # going instead of crashing on ``.get``.
        status = get_result.get("result")
        if not isinstance(status, dict):
            status = {}
        percent = status.get("percent", 0)
        ret_code = status.get("ret_code", -1)

        if ret_code == 3:
            # ret_code 3 is treated as "recognition finished".
            elapsed = time.perf_counter() - started
            print(f" [OK] 识别完成 | 耗时: {elapsed:.2f}s")

            result_data = status.get("result_data", "")
            if result_data:
                print("-" * 70)
                print(" 表格识别结果(原始):")
                try:
                    table_data = json.loads(result_data)
                except (TypeError, ValueError):
                    # Payload is not JSON text; show it raw. str() keeps the
                    # slice valid even if the payload is not a string.
                    print(str(result_data)[:5000])
                else:
                    formatted = json.dumps(table_data, ensure_ascii=False, indent=2)
                    print(formatted[:5000])
                    if len(formatted) > 5000:
                        print(" ... (结果过长,已截断)")

            return get_result

        print(f" 轮询 {attempt}/{max_attempts} | 进度: {percent}%")

    elapsed = time.perf_counter() - started
    print(f" [WARN] 超时(等待 {elapsed:.1f}s),最后状态: ret_code={ret_code}")
    return None


# The ``test_`` prefix would make pytest collect this helper and fail on the
# unknown ``client`` / ``image_data`` fixtures; opt out of collection.
test_table_recognize.__test__ = False
|
||||
|
||||
|
||||
def test_web_image(client, image_data):
    """Run test 5: call ``client.webImage`` and print the recognised lines.

    Args:
        client: Object providing a ``webImage(image_data)`` method that
            returns a dict response.
        image_data: Image content, passed to the client unchanged.

    Returns:
        The response dict on success, or ``None`` when the response contains
        an ``error_code`` key.
    """
    print("\n" + "=" * 70)
    print("[测试5] 网络图片文字识别 - webImage")
    print("=" * 70)

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # wall-clock adjustments.
    started = time.perf_counter()
    result = client.webImage(image_data)
    elapsed = time.perf_counter() - started

    if "error_code" in result:
        print(f" [FAIL] 错误 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
        return None

    # Treat a missing or null "words_result" as "no lines recognised".
    words_result = result.get("words_result") or []
    print(f" [OK] 识别成功 | 耗时: {elapsed:.2f}s | 识别行数: {len(words_result)}")
    print("-" * 70)
    for line_no, item in enumerate(words_result, start=1):
        print(f" [{line_no:3d}] {item.get('words', '')}")

    return result


# The ``test_`` prefix would make pytest collect this helper and fail on the
# unknown ``client`` / ``image_data`` fixtures; opt out of collection.
test_web_image.__test__ = False
|
||||
|
||||
|
||||
def _format_form_row(row):
    """Return the printable text for one header/body row, or None to skip it."""
    if isinstance(row, dict):
        return str(row.get("words", row))
    if isinstance(row, list):
        # str() keeps join() from failing when a cell's "words" is not a string.
        return " | ".join(
            str(cell.get("words", cell)) if isinstance(cell, dict) else str(cell)
            for cell in row
        )
    return None


def test_table_sync(client, image_data):
    """Run test 6: call ``client.form`` and print each form's header and body.

    At most the first 80 body rows of each form are printed. When the
    response has no forms, the raw response (first 3000 characters of its
    JSON dump) is printed instead.

    Args:
        client: Object providing a ``form(image_data)`` method that returns a
            dict response.
        image_data: Image content, passed to the client unchanged.

    Returns:
        The response dict on success, or ``None`` when the response contains
        an ``error_code`` key.
    """
    max_body_rows = 80

    print("\n" + "=" * 70)
    print("[测试6] 表格识别(同步版)- form")
    print("=" * 70)

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # wall-clock adjustments.
    started = time.perf_counter()
    result = client.form(image_data)
    elapsed = time.perf_counter() - started

    if "error_code" in result:
        print(f" [FAIL] 错误 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
        return None

    # Treat a missing or null "forms_result" as "no forms recognised".
    forms_result = result.get("forms_result") or []
    print(f" [OK] 识别成功 | 耗时: {elapsed:.2f}s | 表单数: {len(forms_result)}")
    print("-" * 70)

    for form_no, form in enumerate(forms_result, start=1):
        print(f"\n === 表单 {form_no} ===")
        header = form.get("header", [])
        body = form.get("body", [])

        if header:
            print(" [表头]")
            for row in header:
                text = _format_form_row(row)
                if text is not None:
                    print(f" {text}")

        if body:
            print(" [表体]")
            for row in body[:max_body_rows]:
                text = _format_form_row(row)
                if text is not None:
                    print(f" {text}")
            if len(body) > max_body_rows:
                print(f" ... (共 {len(body)} 行)")

    # Nothing structured to show: dump the raw response for inspection.
    if not forms_result:
        print(f" 原始结果键: {list(result.keys())}")
        formatted = json.dumps(result, ensure_ascii=False, indent=2)
        print(formatted[:3000])

    return result


# The ``test_`` prefix would make pytest collect this helper and fail on the
# unknown ``client`` / ``image_data`` fixtures; opt out of collection.
test_table_sync.__test__ = False
|
||||
|
||||
|
||||
def save_results(results, output_path):
    """Write ``results`` to ``output_path`` as indented UTF-8 JSON.

    Args:
        results: JSON-serialisable object to store.
        output_path: Destination file path; an existing file is overwritten.

    Raises:
        TypeError: If ``results`` contains a value ``json`` cannot serialise.
    """
    # Serialise before opening the file: if serialisation fails, an existing
    # results file is left untouched instead of being truncated half-written.
    payload = json.dumps(results, ensure_ascii=False, indent=2)
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(payload)
    print(f"\n[SAVE] 结果已保存到: {output_path}")
|
||||
|
||||
|
||||
def main():
    """Run every OCR test against IMAGE_PATH, print a summary, and save JSON.

    Checks that the three credential constants are set and that IMAGE_PATH is
    an existing file, then runs the six ``test_*`` helpers in order with one
    ``AipOcr`` client. Each helper that returns a truthy response is recorded,
    and all recorded responses are written to ``test_baidu_ocr_results.json``
    next to this script.

    Raises:
        SystemExit: With status 1 when credentials are missing or the image
            file does not exist.
    """
    print("=" * 70)
    print("百度OCR识别测试 - 产品报价表图片")
    print("=" * 70)

    # Check configuration.
    if not all([APP_ID, API_KEY, SECRET_KEY]):
        print("[FAIL] 百度OCR未配置,请检查 .env 文件中的 BAIDU_OCR_* 变量")
        sys.exit(1)

    print(f" APP_ID: {APP_ID}")
    print(f" API_KEY: {API_KEY[:8]}...")
    # Do not echo any part of the secret; console output is easily captured
    # in logs or screenshots.
    print(" SECRET_KEY: ********")

    # Check the image file; is_file() also rejects a directory at that path.
    image_path = Path(IMAGE_PATH)
    if not image_path.is_file():
        print(f"[FAIL] 图片文件不存在: {IMAGE_PATH}")
        sys.exit(1)

    file_size = image_path.stat().st_size
    print(f" 图片路径: {IMAGE_PATH}")
    print(f" 文件大小: {file_size / 1024:.1f} KB")

    # Create the Baidu OCR client and load the image.
    client = AipOcr(APP_ID, API_KEY, SECRET_KEY)
    image_data = image_path.read_bytes()
    print(f" 图片数据: {len(image_data)} bytes")

    # (result key, display label, runner, whether the response has words_result)
    runs = (
        ("accurate_basic", "basicAccurate(高精度版)", test_accurate_basic, True),
        ("accurate_with_location", "accurate(高精度含位置版)", test_accurate, True),
        ("general_basic", "basicGeneral(标准版)", test_general_basic, True),
        ("table_recognition_async", "tableRecognitionAsync(表格识别-异步)", test_table_recognize, False),
        ("web_image", "webImage(网络图片文字识别)", test_web_image, True),
        ("table_sync", "form(表格识别-同步)", test_table_sync, False),
    )

    # Collect the response of every run that succeeded.
    all_results = {}
    for key, label, runner, has_lines in runs:
        response = runner(client, image_data)
        if not response:
            continue
        entry = {"method": label}
        if has_lines:
            entry["lines"] = len(response.get("words_result") or [])
        entry["data"] = response
        all_results[key] = entry

    # Summary.
    print("\n" + "=" * 70)
    print("测试汇总")
    print("=" * 70)
    for val in all_results.values():
        lines = val.get("lines", "N/A")
        print(f" {val['method']:45s} 识别行数: {lines}")

    # Save results next to this script.
    output_path = Path(__file__).parent / "test_baidu_ocr_results.json"
    save_results(all_results, output_path)

    print("\n[DONE] 所有测试完成!")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user