Files
yiliao/backend/test_baidu_ocr.py

371 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
百度OCR识别测试脚本
测试目标:使用百度OCR对产品报价表图片进行识别,验证识别效果
"""
import os
import sys
import io
import json
import time
from pathlib import Path
# Force UTF-8 console output so the Chinese messages printed below do not raise
# UnicodeEncodeError on Windows terminals that use a legacy code page.
# reconfigure() (Python 3.7+) switches the encoding in place and keeps the
# stream's existing line buffering. Re-wrapping ``.buffer`` in a fresh
# TextIOWrapper (the previous approach) silently dropped line buffering, so the
# table-recognition polling progress was held back until the buffer filled.
# The hasattr() guard also tolerates replaced or missing streams (pythonw, IDE
# consoles) that have no ``.buffer`` / ``reconfigure``.
for _stream in (sys.stdout, sys.stderr):
    if hasattr(_stream, "reconfigure"):
        _stream.reconfigure(encoding="utf-8", errors="replace")
del _stream
# Load environment variables (BAIDU_OCR_* credentials) from the .env file
# located in the same directory as this script.
from dotenv import load_dotenv
load_dotenv(dotenv_path=Path(__file__).with_name(".env"))
from aip import AipOcr
# ==================== Configuration ====================
# Baidu OCR credentials, read from the environment (load_dotenv() may have
# populated them from .env). Empty string when unset.
APP_ID = os.getenv("BAIDU_OCR_APP_ID", "")
API_KEY = os.getenv("BAIDU_OCR_API_KEY", "")
SECRET_KEY = os.getenv("BAIDU_OCR_SECRET_KEY", "")
# Image under test. The default is a machine-specific path from the original
# author's workstation; set BAIDU_OCR_TEST_IMAGE to point the script at another
# image without editing the source. An empty variable falls back to the default.
_DEFAULT_IMAGE_PATH = r"C:\Users\UI\.cursor\projects\c-Users-UI-Desktop\assets\c__Users_UI_AppData_Roaming_Cursor_User_workspaceStorage_6df83b93d4a0651428307542725e79d8_images_ecdbe509-3f63-49c0-a8be-db9facaef857_3_-4dec6c0d-a755-4bda-8780-9e6b20e02df8.png"
IMAGE_PATH = os.getenv("BAIDU_OCR_TEST_IMAGE") or _DEFAULT_IMAGE_PATH
def test_accurate_basic(client, image_data):
    """Test 1: general text recognition, high-accuracy edition (``basicAccurate``).

    Sends *image_data* to ``client.basicAccurate`` and prints every recognised
    line together with the request latency.

    Args:
        client: Baidu ``AipOcr`` client (anything exposing ``basicAccurate``).
        image_data: raw image bytes.

    Returns:
        The raw response dict on success, or ``None`` when the response
        contains an ``error_code``.
    """
    print("\n" + "=" * 70)
    print("[测试1] 通用文字识别(高精度版)- basicAccurate")
    print("=" * 70)
    # perf_counter is monotonic, so the measured latency cannot jump with wall-clock changes.
    start = time.perf_counter()
    result = client.basicAccurate(image_data)
    elapsed = time.perf_counter() - start
    if "error_code" in result:
        print(f" [FAIL] 错误 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
        return None
    words_result = result.get("words_result", [])
    print(f" [OK] 识别成功 | 耗时: {elapsed:.2f}s | 识别行数: {len(words_result)}")
    print("-" * 70)
    for i, item in enumerate(words_result, start=1):
        # .get(): one malformed entry without "words" must not abort the report.
        print(f" [{i:3d}] {item.get('words', '')}")
    return result


# This file matches pytest's ``test_*.py`` pattern; without this flag pytest
# would collect the helper as a test and fail on the missing ``client`` /
# ``image_data`` fixtures.
test_accurate_basic.__test__ = False
def test_accurate(client, image_data):
    """Test 2: high-accuracy text recognition with per-line location (``accurate``).

    Prints each recognised line preceded by its bounding box
    (left/top/width/height taken from the entry's ``location``).

    Args:
        client: Baidu ``AipOcr`` client (anything exposing ``accurate``).
        image_data: raw image bytes.

    Returns:
        The raw response dict on success, or ``None`` when the response
        contains an ``error_code``.
    """
    print("\n" + "=" * 70)
    print("[测试2] 通用文字识别(高精度含位置版)- accurate")
    print("=" * 70)
    start = time.perf_counter()
    result = client.accurate(image_data)
    elapsed = time.perf_counter() - start
    if "error_code" in result:
        print(f" [FAIL] 错误 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
        return None
    words_result = result.get("words_result", [])
    print(f" [OK] 识别成功 | 耗时: {elapsed:.2f}s | 识别行数: {len(words_result)}")
    print("-" * 70)
    for i, item in enumerate(words_result, start=1):
        # A missing *or null* "location" is rendered as zeros instead of raising.
        loc = item.get("location") or {}
        pos_str = (
            f"(x={loc.get('left', 0)}, y={loc.get('top', 0)}, "
            f"w={loc.get('width', 0)}, h={loc.get('height', 0)})"
        )
        print(f" [{i:3d}] {pos_str:40s} {item.get('words', '')}")
    return result


# Not a pytest test despite the name (see the ``test_*.py`` file name); opt out
# of collection so pytest does not fail on missing fixtures.
test_accurate.__test__ = False
def test_general_basic(client, image_data):
    """Test 3: general text recognition, standard edition (``basicGeneral``).

    Args:
        client: Baidu ``AipOcr`` client (anything exposing ``basicGeneral``).
        image_data: raw image bytes.

    Returns:
        The raw response dict on success, or ``None`` when the response
        contains an ``error_code``.
    """
    print("\n" + "=" * 70)
    print("[测试3] 通用文字识别(标准版)- basicGeneral")
    print("=" * 70)
    start = time.perf_counter()
    result = client.basicGeneral(image_data)
    elapsed = time.perf_counter() - start
    if "error_code" in result:
        print(f" [FAIL] 错误 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
        return None
    words_result = result.get("words_result", [])
    print(f" [OK] 识别成功 | 耗时: {elapsed:.2f}s | 识别行数: {len(words_result)}")
    print("-" * 70)
    for i, item in enumerate(words_result, start=1):
        # .get(): one malformed entry without "words" must not abort the report.
        print(f" [{i:3d}] {item.get('words', '')}")
    return result


# Not a pytest test despite the name (see the ``test_*.py`` file name); opt out
# of collection so pytest does not fail on missing fixtures.
test_general_basic.__test__ = False
def test_table_recognize(client, image_data, *, max_attempts=20, interval=3.0):
    """Test 4: asynchronous table recognition (``tableRecognitionAsync``).

    Submits *image_data*, then calls ``client.getTableRecognitionResult`` after
    sleeping *interval* seconds, at most *max_attempts* times (20 x 3 s = 60 s by
    default), until the task reports ``ret_code == 3`` (finished).

    Args:
        client: Baidu ``AipOcr`` client (``tableRecognitionAsync`` and
            ``getTableRecognitionResult``).
        image_data: raw image bytes.
        max_attempts: maximum number of polls before giving up.
        interval: seconds to sleep before each poll.

    Returns:
        The final polling response dict on success. ``None`` in these cases:
        submission fails, no ``request_id`` is returned, a poll fails, or
        polling times out.
    """
    print("\n" + "=" * 70)
    print("[测试4] 表格文字识别 - tableRecognitionAsync")
    print("=" * 70)
    # Submit the table-recognition task.
    start = time.perf_counter()
    result = client.tableRecognitionAsync(image_data)
    if "error_code" in result:
        print(f" [FAIL] 提交失败 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
        return None

    # The submit response's "result" may be a list of {"request_id": ...}
    # entries or a single dict; accept both shapes.
    submitted = result.get("result")
    if isinstance(submitted, list) and submitted:
        submitted = submitted[0]
    request_id = submitted.get("request_id", "") if isinstance(submitted, dict) else ""
    if not request_id:
        print(f" [FAIL] 未获取到 request_id,返回结果: {json.dumps(result, ensure_ascii=False)}")
        return None
    print(f" [INFO] 提交成功 | request_id: {request_id}")
    print(" [INFO] 等待识别结果...")

    ret_code = -1
    for attempt in range(1, max_attempts + 1):
        time.sleep(interval)
        polled = client.getTableRecognitionResult(request_id)
        if "error_code" in polled:
            print(f" [FAIL] 查询失败 ({polled['error_code']}): {polled.get('error_msg', '')}")
            return None
        status = polled.get("result")
        if not isinstance(status, dict):
            status = {}  # tolerate a null/odd "result" instead of crashing on .get
        percent = status.get("percent", 0)
        ret_code = status.get("ret_code", -1)
        if ret_code == 3:  # task finished
            elapsed = time.perf_counter() - start
            print(f" [OK] 识别完成 | 耗时: {elapsed:.2f}s")
            result_data = status.get("result_data", "")
            if result_data:
                print("-" * 70)
                print(" 表格识别结果(原始):")
                try:
                    formatted = json.dumps(json.loads(result_data), ensure_ascii=False, indent=2)
                except (TypeError, ValueError):
                    # Not a JSON string (presumably e.g. a download URL when
                    # the service returns an Excel file — TODO confirm the
                    # result_type default); print it verbatim.
                    print(str(result_data)[:5000])
                else:
                    print(formatted[:5000])
                    if len(formatted) > 5000:
                        print(" ... (结果过长,已截断)")
            return polled
        print(f" 轮询 {attempt}/{max_attempts} | 进度: {percent}%")

    elapsed = time.perf_counter() - start
    print(f" [WARN] 超时(等待 {elapsed:.1f}s),最后状态: ret_code={ret_code}")
    return None


# Not a pytest test despite the name (see the ``test_*.py`` file name); opt out
# of collection so pytest does not fail on missing fixtures.
test_table_recognize.__test__ = False
def test_web_image(client, image_data):
    """Test 5: web-image text recognition (``webImage``).

    Args:
        client: Baidu ``AipOcr`` client (anything exposing ``webImage``).
        image_data: raw image bytes.

    Returns:
        The raw response dict on success, or ``None`` when the response
        contains an ``error_code``.
    """
    print("\n" + "=" * 70)
    print("[测试5] 网络图片文字识别 - webImage")
    print("=" * 70)
    start = time.perf_counter()
    result = client.webImage(image_data)
    elapsed = time.perf_counter() - start
    if "error_code" in result:
        print(f" [FAIL] 错误 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
        return None
    words_result = result.get("words_result", [])
    print(f" [OK] 识别成功 | 耗时: {elapsed:.2f}s | 识别行数: {len(words_result)}")
    print("-" * 70)
    for i, item in enumerate(words_result, start=1):
        # .get(): one malformed entry without "words" must not abort the report.
        print(f" [{i:3d}] {item.get('words', '')}")
    return result


# Not a pytest test despite the name (see the ``test_*.py`` file name); opt out
# of collection so pytest does not fail on missing fixtures.
test_web_image.__test__ = False
def _form_cell_text(cell):
    """Return the display text of one ``form`` response cell (or the cell itself)."""
    if isinstance(cell, dict):
        if "words" in cell:
            return cell["words"]
        # NOTE(review): Baidu's form response may key cell text as "word"
        # rather than "words" — TODO confirm; fall back to the raw dict.
        return cell.get("word", cell)
    return cell


def _form_row_text(row):
    """Render one header/body/footer entry as a line, or None if its shape is unknown."""
    if isinstance(row, dict):
        return str(_form_cell_text(row))
    if isinstance(row, list):
        # str(): cell values are not guaranteed to be strings, and join() needs str.
        return " | ".join(str(_form_cell_text(cell)) for cell in row)
    return None


def _print_form_section(label, rows, limit=None):
    """Print *label* then one line per row (first *limit* rows only); no-op if empty."""
    if not rows:
        return
    print(label)
    for row in rows[:limit]:
        text = _form_row_text(row)
        if text is not None:
            print(f" {text}")
    if limit is not None and len(rows) > limit:
        print(f" ... (共 {len(rows)} 行)")


def test_table_sync(client, image_data):
    """Test 6: synchronous table recognition (``form``).

    Prints the header, body (first 80 rows) and footer of every table in
    ``forms_result``. When ``forms_result`` is empty, dumps the raw response
    (first 3000 characters) for inspection instead.

    Args:
        client: Baidu ``AipOcr`` client (anything exposing ``form``).
        image_data: raw image bytes.

    Returns:
        The raw response dict on success, or ``None`` when the response
        contains an ``error_code``.
    """
    print("\n" + "=" * 70)
    print("[测试6] 表格识别(同步版)- form")
    print("=" * 70)
    start = time.perf_counter()
    result = client.form(image_data)
    elapsed = time.perf_counter() - start
    if "error_code" in result:
        print(f" [FAIL] 错误 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
        return None
    forms_result = result.get("forms_result", [])
    print(f" [OK] 识别成功 | 耗时: {elapsed:.2f}s | 表单数: {len(forms_result)}")
    print("-" * 70)
    for f_idx, form in enumerate(forms_result, start=1):
        print(f"\n === 表单 {f_idx} ===")
        _print_form_section(" [表头]", form.get("header", []))
        _print_form_section(" [表体]", form.get("body", []), limit=80)
        # The footer used to be extracted but never shown.
        _print_form_section(" [表尾]", form.get("footer", []))
    # Nothing parsed: show the raw response keys and payload instead.
    if not forms_result:
        print(f" 原始结果键: {list(result.keys())}")
        formatted = json.dumps(result, ensure_ascii=False, indent=2)
        print(formatted[:3000])
    return result


# Not a pytest test despite the name (see the ``test_*.py`` file name); opt out
# of collection so pytest does not fail on missing fixtures.
test_table_sync.__test__ = False
def save_results(results, output_path):
    """Write *results* to *output_path* as indented, UTF-8, non-ASCII-escaped JSON.

    Args:
        results: any JSON-serialisable object.
        output_path: destination file path (str or path-like); overwritten.
    """
    with open(output_path, mode="w", encoding="utf-8") as handle:
        json.dump(results, handle, indent=2, ensure_ascii=False)
    message = f"\n[SAVE] 结果已保存到: {output_path}"
    print(message)
def main():
    """Run every Baidu OCR variant against IMAGE_PATH and save a JSON summary.

    Exits with status 1 when the BAIDU_OCR_* credentials are missing or the
    image file does not exist. Each recognition test runs independently: an
    exception raised by one (e.g. a network error inside the SDK) is reported
    and the remaining tests still run, so the results gathered so far are
    always written to ``test_baidu_ocr_results.json`` next to this script.
    """
    print("=" * 70)
    print("百度OCR识别测试 - 产品报价表图片")
    print("=" * 70)

    # Check configuration before any network call.
    if not all([APP_ID, API_KEY, SECRET_KEY]):
        print("[FAIL] 百度OCR未配置,请检查 .env 文件中的 BAIDU_OCR_* 变量")
        sys.exit(1)
    print(f" APP_ID: {APP_ID}")
    print(f" API_KEY: {API_KEY[:8]}...")
    # Never echo any part of the secret key into console output / CI logs.
    print(" SECRET_KEY: ********")

    # is_file() rather than exists(): a directory path would otherwise pass
    # the check and fail later on read.
    image_path = Path(IMAGE_PATH)
    if not image_path.is_file():
        print(f"[FAIL] 图片文件不存在: {IMAGE_PATH}")
        sys.exit(1)
    file_size = image_path.stat().st_size
    print(f" 图片路径: {IMAGE_PATH}")
    print(f" 文件大小: {file_size / 1024:.1f} KB")

    client = AipOcr(APP_ID, API_KEY, SECRET_KEY)
    image_data = image_path.read_bytes()
    print(f" 图片数据: {len(image_data)} bytes")

    # (result key, label stored in the JSON, test function,
    #  whether the response carries "words_result" lines to count)
    test_plan = [
        ("accurate_basic", "basicAccurate(高精度版)", test_accurate_basic, True),
        ("accurate_with_location", "accurate(高精度含位置版)", test_accurate, True),
        ("general_basic", "basicGeneral(标准版)", test_general_basic, True),
        ("table_recognition_async", "tableRecognitionAsync(表格识别-异步)", test_table_recognize, False),
        ("web_image", "webImage(网络图片文字识别)", test_web_image, True),
        ("table_sync", "form(表格识别-同步)", test_table_sync, False),
    ]
    all_results = {}
    for key, label, run_test, counts_lines in test_plan:
        try:
            result = run_test(client, image_data)
        except Exception as exc:  # top-level boundary: report and keep going
            print(f" [FAIL] {label} 调用异常: {exc!r}")
            continue
        if not result:
            continue
        entry = {"method": label}
        if counts_lines:
            entry["lines"] = len(result.get("words_result", []))
        entry["data"] = result
        all_results[key] = entry

    # ---- Summary ----
    print("\n" + "=" * 70)
    print("测试汇总")
    print("=" * 70)
    for entry in all_results.values():
        print(f" {entry['method']:45s} 识别行数: {entry.get('lines', 'N/A')}")

    # Save results next to this script.
    output_path = Path(__file__).parent / "test_baidu_ocr_results.json"
    save_results(all_results, output_path)
    print("\n[DONE] 所有测试完成!")
if __name__ == "__main__":
    # Run the OCR comparison only when executed as a script, not when imported.
    main()