Files
yiliao/backend/test_baidu_ocr.py

371 lines
12 KiB
Python
Raw Permalink Normal View History

# -*- coding: utf-8 -*-
"""
百度OCR识别测试脚本
测试目标: 使用百度OCR对产品报价表图片进行识别,验证识别效果
"""
import os
import sys
import io
import json
import time
from pathlib import Path
# Force UTF-8 console output (original intent: fix UTF-8 output on Windows terminals).
# reconfigure() changes the existing streams in place, so they keep their buffering
# mode: progress lines printed while polling appear immediately instead of being held
# in a freshly created, fully buffered TextIOWrapper.
def _force_utf8_console():
    """Switch stdout/stderr to UTF-8 with ``errors="replace"`` where the stream allows it."""
    for stream in (sys.stdout, sys.stderr):
        # Streams substituted by IDEs / test runners may not support reconfigure();
        # leave those untouched rather than crash at import time.
        if hasattr(stream, "reconfigure"):
            stream.reconfigure(encoding="utf-8", errors="replace")


_force_utf8_console()
# Load the BAIDU_OCR_* variables from the .env file next to this script.
from dotenv import load_dotenv
load_dotenv(Path(__file__).parent / ".env")
from aip import AipOcr
# ==================== Configuration ====================
APP_ID = os.getenv("BAIDU_OCR_APP_ID", "")
API_KEY = os.getenv("BAIDU_OCR_API_KEY", "")
SECRET_KEY = os.getenv("BAIDU_OCR_SECRET_KEY", "")
# Image under test. Defaults to the original author's local file; set the
# BAIDU_OCR_TEST_IMAGE environment variable to use another image without editing
# the source (an empty value falls back to the default).
_DEFAULT_IMAGE_PATH = r"C:\Users\UI\.cursor\projects\c-Users-UI-Desktop\assets\c__Users_UI_AppData_Roaming_Cursor_User_workspaceStorage_6df83b93d4a0651428307542725e79d8_images_ecdbe509-3f63-49c0-a8be-db9facaef857_3_-4dec6c0d-a755-4bda-8780-9e6b20e02df8.png"
IMAGE_PATH = os.getenv("BAIDU_OCR_TEST_IMAGE") or _DEFAULT_IMAGE_PATH
def test_accurate_basic(client, image_data):
    """Test 1: general text recognition, high-accuracy edition (``basicAccurate``).

    Args:
        client: OCR client exposing ``basicAccurate`` (an ``AipOcr`` instance in ``main``).
        image_data: Raw image bytes.

    Returns:
        The raw response dict on success, or None when it contains ``error_code``.
    """
    print("\n" + "=" * 70)
    print("[测试1] 通用文字识别(高精度版)- basicAccurate")
    print("=" * 70)
    # perf_counter is monotonic, so wall-clock adjustments cannot skew the timing.
    start = time.perf_counter()
    result = client.basicAccurate(image_data)
    elapsed = time.perf_counter() - start
    if "error_code" in result:
        print(f" [FAIL] 错误 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
        return None
    words_result = result.get("words_result", [])
    print(f" [OK] 识别成功 | 耗时: {elapsed:.2f}s | 识别行数: {len(words_result)}")
    print("-" * 70)
    for line_no, item in enumerate(words_result, start=1):
        # .get: an item without "words" should not abort the whole listing.
        print(f" [{line_no:3d}] {item.get('words', '')}")
    return result


# This file matches pytest's test_*.py pattern; without this marker pytest would
# collect the helper as a test and fail looking for "client"/"image_data" fixtures.
test_accurate_basic.__test__ = False
def test_accurate(client, image_data):
    """Test 2: general text recognition, high-accuracy edition with positions (``accurate``).

    Each recognised line is printed together with its ``location`` box
    (left/top/width/height; missing values are shown as 0).

    Args:
        client: OCR client exposing ``accurate`` (an ``AipOcr`` instance in ``main``).
        image_data: Raw image bytes.

    Returns:
        The raw response dict on success, or None when it contains ``error_code``.
    """
    print("\n" + "=" * 70)
    print("[测试2] 通用文字识别(高精度含位置版)- accurate")
    print("=" * 70)
    # perf_counter is monotonic, so wall-clock adjustments cannot skew the timing.
    start = time.perf_counter()
    result = client.accurate(image_data)
    elapsed = time.perf_counter() - start
    if "error_code" in result:
        print(f" [FAIL] 错误 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
        return None
    words_result = result.get("words_result", [])
    print(f" [OK] 识别成功 | 耗时: {elapsed:.2f}s | 识别行数: {len(words_result)}")
    print("-" * 70)
    for line_no, item in enumerate(words_result, start=1):
        loc = item.get("location", {})
        pos_str = f"(x={loc.get('left',0)}, y={loc.get('top',0)}, w={loc.get('width',0)}, h={loc.get('height',0)})"
        # .get: an item without "words" should not abort the whole listing.
        print(f" [{line_no:3d}] {pos_str:40s} {item.get('words', '')}")
    return result


# This file matches pytest's test_*.py pattern; without this marker pytest would
# collect the helper as a test and fail looking for "client"/"image_data" fixtures.
test_accurate.__test__ = False
def test_general_basic(client, image_data):
    """Test 3: general text recognition, standard edition (``basicGeneral``).

    Args:
        client: OCR client exposing ``basicGeneral`` (an ``AipOcr`` instance in ``main``).
        image_data: Raw image bytes.

    Returns:
        The raw response dict on success, or None when it contains ``error_code``.
    """
    print("\n" + "=" * 70)
    print("[测试3] 通用文字识别(标准版)- basicGeneral")
    print("=" * 70)
    # perf_counter is monotonic, so wall-clock adjustments cannot skew the timing.
    start = time.perf_counter()
    result = client.basicGeneral(image_data)
    elapsed = time.perf_counter() - start
    if "error_code" in result:
        print(f" [FAIL] 错误 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
        return None
    words_result = result.get("words_result", [])
    print(f" [OK] 识别成功 | 耗时: {elapsed:.2f}s | 识别行数: {len(words_result)}")
    print("-" * 70)
    for line_no, item in enumerate(words_result, start=1):
        # .get: an item without "words" should not abort the whole listing.
        print(f" [{line_no:3d}] {item.get('words', '')}")
    return result


# This file matches pytest's test_*.py pattern; without this marker pytest would
# collect the helper as a test and fail looking for "client"/"image_data" fixtures.
test_general_basic.__test__ = False
def _print_table_result_data(result_data, limit=5000):
    """Print a finished task's ``result_data``, pretty-printed when it is JSON text.

    Output is cut at ``limit`` characters with a truncation note; empty data prints nothing.
    """
    if not result_data:
        return
    print("-" * 70)
    print(" 表格识别结果(原始):")
    try:
        text = json.dumps(json.loads(result_data), ensure_ascii=False, indent=2)
    except (TypeError, ValueError):
        # Not JSON text (the service may return e.g. a file URL here) - show it verbatim.
        text = str(result_data)
    print(text[:limit])
    if len(text) > limit:
        print(" ... (结果过长,已截断)")


def test_table_recognize(client, image_data, *, max_attempts=20, poll_interval=3.0):
    """Test 4: asynchronous table recognition (``tableRecognitionAsync``).

    Submits the image, then polls ``getTableRecognitionResult`` up to
    ``max_attempts`` times, sleeping ``poll_interval`` seconds before each poll
    (defaults: 20 x 3 s, i.e. at most about 60 s). A poll whose
    ``result.ret_code`` equals 3 is treated as finished.

    Args:
        client: OCR client exposing ``tableRecognitionAsync`` and
            ``getTableRecognitionResult`` (an ``AipOcr`` instance in ``main``).
        image_data: Raw image bytes.
        max_attempts: Maximum number of polls before giving up.
        poll_interval: Seconds to sleep before each poll.

    Returns:
        The final poll response on success; None on a submit/query error, a
        missing ``request_id``, or timeout.
    """
    print("\n" + "=" * 70)
    print("[测试4] 表格文字识别 - tableRecognitionAsync")
    print("=" * 70)
    # Timing includes the polling sleeps: it measures end-to-end latency.
    start = time.perf_counter()
    result = client.tableRecognitionAsync(image_data)
    if "error_code" in result:
        print(f" [FAIL] 提交失败 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
        return None
    # The submit response's "result" may be a list of tasks or a single dict; accept
    # both, and treat anything else (including a non-dict list entry) as "no id".
    submitted = result.get("result", [])
    if isinstance(submitted, list) and submitted:
        submitted = submitted[0]
    request_id = submitted.get("request_id", "") if isinstance(submitted, dict) else ""
    if not request_id:
        print(f" [FAIL] 未获取到 request_id,返回结果: {json.dumps(result, ensure_ascii=False)}")
        return None
    print(f" [INFO] 提交成功 | request_id: {request_id}")
    print(" [INFO] 等待识别结果...")
    ret_code = -1
    for attempt in range(1, max_attempts + 1):
        time.sleep(poll_interval)
        poll = client.getTableRecognitionResult(request_id)
        if "error_code" in poll:
            print(f" [FAIL] 查询失败 ({poll['error_code']}): {poll.get('error_msg', '')}")
            return None
        status = poll.get("result") or {}
        percent = status.get("percent", 0)
        ret_code = status.get("ret_code", -1)
        if ret_code == 3:
            elapsed = time.perf_counter() - start
            print(f" [OK] 识别完成 | 耗时: {elapsed:.2f}s")
            _print_table_result_data(status.get("result_data", ""))
            return poll
        print(f" 轮询 {attempt}/{max_attempts} | 进度: {percent}%")
    elapsed = time.perf_counter() - start
    print(f" [WARN] 超时(等待 {elapsed:.1f}s),最后状态: ret_code={ret_code}")
    return None


# This file matches pytest's test_*.py pattern; without this marker pytest would
# collect the helper as a test and fail looking for "client"/"image_data" fixtures.
test_table_recognize.__test__ = False
def test_web_image(client, image_data):
    """Test 5: web-image text recognition (``webImage``).

    Args:
        client: OCR client exposing ``webImage`` (an ``AipOcr`` instance in ``main``).
        image_data: Raw image bytes.

    Returns:
        The raw response dict on success, or None when it contains ``error_code``.
    """
    print("\n" + "=" * 70)
    print("[测试5] 网络图片文字识别 - webImage")
    print("=" * 70)
    # perf_counter is monotonic, so wall-clock adjustments cannot skew the timing.
    start = time.perf_counter()
    result = client.webImage(image_data)
    elapsed = time.perf_counter() - start
    if "error_code" in result:
        print(f" [FAIL] 错误 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
        return None
    words_result = result.get("words_result", [])
    print(f" [OK] 识别成功 | 耗时: {elapsed:.2f}s | 识别行数: {len(words_result)}")
    print("-" * 70)
    for line_no, item in enumerate(words_result, start=1):
        # .get: an item without "words" should not abort the whole listing.
        print(f" [{line_no:3d}] {item.get('words', '')}")
    return result


# This file matches pytest's test_*.py pattern; without this marker pytest would
# collect the helper as a test and fail looking for "client"/"image_data" fixtures.
test_web_image.__test__ = False
def _table_sync_row_text(row):
    """Return the display text of one ``form`` header/body/footer entry, or None to skip it.

    A dict entry shows its ``words`` value (the whole dict when absent); a list entry
    is treated as a row of cells joined with `` | ``; any other type is skipped.
    """
    if isinstance(row, dict):
        return f"{row.get('words', row)}"
    if isinstance(row, list):
        # str(): a non-string "words" value must not make join() raise TypeError.
        return " | ".join(
            str(cell.get("words", cell)) if isinstance(cell, dict) else str(cell)
            for cell in row
        )
    return None


def _print_table_sync_section(title, rows, limit=None):
    """Print ``title`` and the rendered ``rows`` of one form section; empty sections print nothing.

    With ``limit`` set, only the first ``limit`` rows are shown, followed by a note
    with the total count when rows were left out.
    """
    if not rows:
        return
    print(title)
    shown = rows if limit is None else rows[:limit]
    for row in shown:
        text = _table_sync_row_text(row)
        if text is not None:
            print(f" {text}")
    if limit is not None and len(rows) > limit:
        print(f" ... (共 {len(rows)} 行)")


def test_table_sync(client, image_data):
    """Test 6: synchronous table recognition (``form``).

    Prints the header, body (first 80 entries) and footer of every entry in
    ``forms_result``; when none were returned, dumps the raw response instead.

    Args:
        client: OCR client exposing ``form`` (an ``AipOcr`` instance in ``main``).
        image_data: Raw image bytes.

    Returns:
        The raw response dict on success, or None when it contains ``error_code``.
    """
    print("\n" + "=" * 70)
    print("[测试6] 表格识别(同步版)- form")
    print("=" * 70)
    # perf_counter is monotonic, so wall-clock adjustments cannot skew the timing.
    start = time.perf_counter()
    result = client.form(image_data)
    elapsed = time.perf_counter() - start
    if "error_code" in result:
        print(f" [FAIL] 错误 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
        return None
    forms_result = result.get("forms_result", [])
    print(f" [OK] 识别成功 | 耗时: {elapsed:.2f}s | 表单数: {len(forms_result)}")
    print("-" * 70)
    for form_no, form in enumerate(forms_result, start=1):
        print(f"\n === 表单 {form_no} ===")
        _print_table_sync_section(" [表头]", form.get("header", []))
        _print_table_sync_section(" [表体]", form.get("body", []), limit=80)
        # Footer entries are rendered the same way as header entries.
        _print_table_sync_section(" [表尾]", form.get("footer", []))
    # No forms recognised: dump the raw response so its structure can be inspected.
    if not forms_result:
        print(f" 原始结果键: {list(result.keys())}")
        print(json.dumps(result, ensure_ascii=False, indent=2)[:3000])
    return result


# This file matches pytest's test_*.py pattern; without this marker pytest would
# collect the helper as a test and fail looking for "client"/"image_data" fixtures.
test_table_sync.__test__ = False
def save_results(results, output_path):
    """Write ``results`` to ``output_path`` as indented UTF-8 JSON, keeping non-ASCII text readable.

    Any existing file is overwritten; a confirmation line is printed afterwards.
    """
    dump_options = {"ensure_ascii": False, "indent": 2}
    with open(output_path, mode="w", encoding="utf-8") as out_file:
        json.dump(results, out_file, **dump_options)
    print(f"\n[SAVE] 结果已保存到: {output_path}")
def main():
    """Run every OCR variant against IMAGE_PATH, print a summary and save all results.

    Exits with status 1 when the BAIDU_OCR_* credentials are missing or the image
    file does not exist. Results of the successful variants are written to
    ``test_baidu_ocr_results.json`` next to this script.
    """
    print("=" * 70)
    print("百度OCR识别测试 - 产品报价表图片")
    print("=" * 70)
    # Check configuration.
    if not all([APP_ID, API_KEY, SECRET_KEY]):
        print("[FAIL] 百度OCR未配置,请检查 .env 文件中的 BAIDU_OCR_* 变量")
        sys.exit(1)
    print(f" APP_ID: {APP_ID}")
    print(f" API_KEY: {API_KEY[:8]}...")
    # Never echo part of the secret itself; its length is enough to spot a truncated value.
    print(f" SECRET_KEY: ******** (长度 {len(SECRET_KEY)})")
    # Check the image file.
    image_path = Path(IMAGE_PATH)
    if not image_path.exists():
        print(f"[FAIL] 图片文件不存在: {IMAGE_PATH}")
        sys.exit(1)
    file_size = image_path.stat().st_size
    print(f" 图片路径: {IMAGE_PATH}")
    print(f" 文件大小: {file_size / 1024:.1f} KB")
    client = AipOcr(APP_ID, API_KEY, SECRET_KEY)
    image_data = image_path.read_bytes()
    print(f" 图片数据: {len(image_data)} bytes")

    # (result key, runner, summary label, whether to record the words_result line count).
    # The keys and labels end up in the saved JSON file.
    test_plan = [
        ("accurate_basic", test_accurate_basic, "basicAccurate(高精度版)", True),
        ("accurate_with_location", test_accurate, "accurate(高精度含位置版)", True),
        ("general_basic", test_general_basic, "basicGeneral(标准版)", True),
        ("table_recognition_async", test_table_recognize, "tableRecognitionAsync(表格识别-异步)", False),
        ("web_image", test_web_image, "webImage(网络图片文字识别)", True),
        ("table_sync", test_table_sync, "form(表格识别-同步)", False),
    ]
    all_results = {}
    for key, runner, method, count_lines in test_plan:
        try:
            outcome = runner(client, image_data)
        except Exception as exc:
            # Top-level boundary: one API raising (e.g. a network error) must not
            # abort the remaining tests or lose the results collected so far.
            print(f" [FAIL] {method} 调用异常: {exc!r}")
            continue
        if not outcome:
            continue
        entry = {"method": method}
        if count_lines:
            entry["lines"] = len(outcome.get("words_result", []))
        entry["data"] = outcome
        all_results[key] = entry

    # ---- Summary ----
    print("\n" + "=" * 70)
    print("测试汇总")
    print("=" * 70)
    for entry in all_results.values():
        lines = entry.get("lines", "N/A")
        print(f" {entry['method']:45s} 识别行数: {lines}")
    # Save the results next to this script.
    output_path = Path(__file__).parent / "test_baidu_ocr_results.json"
    save_results(all_results, output_path)
    print("\n[DONE] 所有测试完成!")


if __name__ == "__main__":
    main()