371 lines
12 KiB
Python
371 lines
12 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
Baidu OCR recognition test script.

Goal: run Baidu OCR on a product quotation sheet image and verify the recognition quality.
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import io
|
||
import json
|
||
import time
|
||
from pathlib import Path
|
||
|
||
# Force UTF-8 output on stdout/stderr so Chinese text prints correctly on
# Windows terminals.
#
# reconfigure() switches the encoding of the existing stream in place, which
# keeps its buffering behaviour (progress lines still show up as they are
# printed). Streams without reconfigure() fall back to wrapping their binary
# buffer; streams that are None or expose neither attribute are left alone
# instead of raising AttributeError at import time.
for _stream_name in ("stdout", "stderr"):
    _stream = getattr(sys, _stream_name, None)
    if hasattr(_stream, "reconfigure"):
        _stream.reconfigure(encoding="utf-8", errors="replace")
    elif hasattr(_stream, "buffer"):
        setattr(
            sys,
            _stream_name,
            io.TextIOWrapper(_stream.buffer, encoding="utf-8", errors="replace"),
        )
del _stream_name, _stream
|
||
|
||
# Load environment variables from the .env file located next to this script.
from dotenv import load_dotenv

load_dotenv(Path(__file__).with_name(".env"))
|
||
|
||
from aip import AipOcr
|
||
|
||
|
||
# ==================== Configuration ====================

# Baidu OCR credentials, read from environment variables. Each one falls back
# to an empty string when unset.
APP_ID = os.getenv("BAIDU_OCR_APP_ID", "")
API_KEY = os.getenv("BAIDU_OCR_API_KEY", "")
SECRET_KEY = os.getenv("BAIDU_OCR_SECRET_KEY", "")

# Path of the image used for the tests. Set BAIDU_OCR_TEST_IMAGE to run the
# script against a different image without editing this file; when the
# variable is unset or empty the original hard-coded path is used.
IMAGE_PATH = os.getenv("BAIDU_OCR_TEST_IMAGE") or (
    r"C:\Users\UI\.cursor\projects\c-Users-UI-Desktop\assets\c__Users_UI_AppData_Roaming_Cursor_User_workspaceStorage_6df83b93d4a0651428307542725e79d8_images_ecdbe509-3f63-49c0-a8be-db9facaef857_3_-4dec6c0d-a755-4bda-8780-9e6b20e02df8.png"
)
|
||
|
||
|
||
def test_accurate_basic(client, image_data):
|
||
"""测试1:通用文字识别(高精度版)- basicAccurate"""
|
||
print("\n" + "=" * 70)
|
||
print("[测试1] 通用文字识别(高精度版)- basicAccurate")
|
||
print("=" * 70)
|
||
|
||
start = time.time()
|
||
result = client.basicAccurate(image_data)
|
||
elapsed = time.time() - start
|
||
|
||
if "error_code" in result:
|
||
print(f" [FAIL] 错误 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
|
||
return None
|
||
|
||
words_result = result.get("words_result", [])
|
||
print(f" [OK] 识别成功 | 耗时: {elapsed:.2f}s | 识别行数: {len(words_result)}")
|
||
print("-" * 70)
|
||
for i, item in enumerate(words_result):
|
||
print(f" [{i+1:3d}] {item['words']}")
|
||
|
||
return result
|
||
|
||
|
||
def test_accurate(client, image_data):
|
||
"""测试2:通用文字识别(高精度含位置版)- accurate"""
|
||
print("\n" + "=" * 70)
|
||
print("[测试2] 通用文字识别(高精度含位置版)- accurate")
|
||
print("=" * 70)
|
||
|
||
start = time.time()
|
||
result = client.accurate(image_data)
|
||
elapsed = time.time() - start
|
||
|
||
if "error_code" in result:
|
||
print(f" [FAIL] 错误 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
|
||
return None
|
||
|
||
words_result = result.get("words_result", [])
|
||
print(f" [OK] 识别成功 | 耗时: {elapsed:.2f}s | 识别行数: {len(words_result)}")
|
||
print("-" * 70)
|
||
for i, item in enumerate(words_result):
|
||
loc = item.get("location", {})
|
||
pos_str = f"(x={loc.get('left',0)}, y={loc.get('top',0)}, w={loc.get('width',0)}, h={loc.get('height',0)})"
|
||
print(f" [{i+1:3d}] {pos_str:40s} {item['words']}")
|
||
|
||
return result
|
||
|
||
|
||
def test_general_basic(client, image_data):
|
||
"""测试3:通用文字识别(标准版)- basicGeneral"""
|
||
print("\n" + "=" * 70)
|
||
print("[测试3] 通用文字识别(标准版)- basicGeneral")
|
||
print("=" * 70)
|
||
|
||
start = time.time()
|
||
result = client.basicGeneral(image_data)
|
||
elapsed = time.time() - start
|
||
|
||
if "error_code" in result:
|
||
print(f" [FAIL] 错误 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
|
||
return None
|
||
|
||
words_result = result.get("words_result", [])
|
||
print(f" [OK] 识别成功 | 耗时: {elapsed:.2f}s | 识别行数: {len(words_result)}")
|
||
print("-" * 70)
|
||
for i, item in enumerate(words_result):
|
||
print(f" [{i+1:3d}] {item['words']}")
|
||
|
||
return result
|
||
|
||
|
||
def test_table_recognize(client, image_data):
|
||
"""测试4:表格文字识别 - tableRecognition (异步)"""
|
||
print("\n" + "=" * 70)
|
||
print("[测试4] 表格文字识别 - tableRecognitionAsync")
|
||
print("=" * 70)
|
||
|
||
# 提交表格识别请求
|
||
start = time.time()
|
||
result = client.tableRecognitionAsync(image_data)
|
||
|
||
if "error_code" in result:
|
||
print(f" [FAIL] 提交失败 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
|
||
return None
|
||
|
||
# tableRecognitionAsync 返回格式可能不同,兼容处理
|
||
result_list = result.get("result", [])
|
||
if isinstance(result_list, list) and len(result_list) > 0:
|
||
request_id = result_list[0].get("request_id", "")
|
||
elif isinstance(result_list, dict):
|
||
request_id = result_list.get("request_id", "")
|
||
else:
|
||
request_id = ""
|
||
|
||
if not request_id:
|
||
print(f" [FAIL] 未获取到 request_id,返回结果: {json.dumps(result, ensure_ascii=False)}")
|
||
return None
|
||
|
||
print(f" [INFO] 提交成功 | request_id: {request_id}")
|
||
print(" [INFO] 等待识别结果...")
|
||
|
||
# 轮询获取结果(最多等60秒)
|
||
ret_code = -1
|
||
for attempt in range(20):
|
||
time.sleep(3)
|
||
get_result = client.getTableRecognitionResult(request_id)
|
||
|
||
if "error_code" in get_result:
|
||
print(f" [FAIL] 查询失败 ({get_result['error_code']}): {get_result.get('error_msg', '')}")
|
||
return None
|
||
|
||
percent = get_result.get("result", {}).get("percent", 0)
|
||
ret_code = get_result.get("result", {}).get("ret_code", -1)
|
||
|
||
if ret_code == 3:
|
||
# 识别完成
|
||
elapsed = time.time() - start
|
||
print(f" [OK] 识别完成 | 耗时: {elapsed:.2f}s")
|
||
|
||
# 解析表格结果
|
||
result_data = get_result.get("result", {}).get("result_data", "")
|
||
if result_data:
|
||
print("-" * 70)
|
||
print(" 表格识别结果(原始):")
|
||
try:
|
||
table_data = json.loads(result_data)
|
||
formatted = json.dumps(table_data, ensure_ascii=False, indent=2)
|
||
print(formatted[:5000])
|
||
if len(formatted) > 5000:
|
||
print(" ... (结果过长,已截断)")
|
||
except Exception:
|
||
print(result_data[:5000])
|
||
|
||
return get_result
|
||
|
||
print(f" 轮询 {attempt+1}/20 | 进度: {percent}%")
|
||
|
||
elapsed = time.time() - start
|
||
print(f" [WARN] 超时(等待 {elapsed:.1f}s),最后状态: ret_code={ret_code}")
|
||
return None
|
||
|
||
|
||
def test_web_image(client, image_data):
|
||
"""测试5:网络图片文字识别 - webImage"""
|
||
print("\n" + "=" * 70)
|
||
print("[测试5] 网络图片文字识别 - webImage")
|
||
print("=" * 70)
|
||
|
||
start = time.time()
|
||
result = client.webImage(image_data)
|
||
elapsed = time.time() - start
|
||
|
||
if "error_code" in result:
|
||
print(f" [FAIL] 错误 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
|
||
return None
|
||
|
||
words_result = result.get("words_result", [])
|
||
print(f" [OK] 识别成功 | 耗时: {elapsed:.2f}s | 识别行数: {len(words_result)}")
|
||
print("-" * 70)
|
||
for i, item in enumerate(words_result):
|
||
print(f" [{i+1:3d}] {item['words']}")
|
||
|
||
return result
|
||
|
||
|
||
def test_table_sync(client, image_data):
|
||
"""测试6:表格识别(同步版)- form"""
|
||
print("\n" + "=" * 70)
|
||
print("[测试6] 表格识别(同步版)- form")
|
||
print("=" * 70)
|
||
|
||
start = time.time()
|
||
result = client.form(image_data)
|
||
elapsed = time.time() - start
|
||
|
||
if "error_code" in result:
|
||
print(f" [FAIL] 错误 ({result['error_code']}): {result.get('error_msg', '未知错误')}")
|
||
return None
|
||
|
||
forms_result = result.get("forms_result", [])
|
||
print(f" [OK] 识别成功 | 耗时: {elapsed:.2f}s | 表单数: {len(forms_result)}")
|
||
print("-" * 70)
|
||
|
||
# 打印表格内容
|
||
for f_idx, form in enumerate(forms_result):
|
||
print(f"\n === 表单 {f_idx + 1} ===")
|
||
header = form.get("header", [])
|
||
body = form.get("body", [])
|
||
footer = form.get("footer", [])
|
||
|
||
if header:
|
||
print(" [表头]")
|
||
for row in header:
|
||
if isinstance(row, dict):
|
||
print(f" {row.get('words', row)}")
|
||
elif isinstance(row, list):
|
||
row_text = " | ".join(
|
||
cell.get("words", str(cell)) if isinstance(cell, dict) else str(cell)
|
||
for cell in row
|
||
)
|
||
print(f" {row_text}")
|
||
|
||
if body:
|
||
print(" [表体]")
|
||
for r_idx, row in enumerate(body[:80]):
|
||
if isinstance(row, dict):
|
||
print(f" {row.get('words', row)}")
|
||
elif isinstance(row, list):
|
||
row_text = " | ".join(
|
||
cell.get("words", str(cell)) if isinstance(cell, dict) else str(cell)
|
||
for cell in row
|
||
)
|
||
print(f" {row_text}")
|
||
if len(body) > 80:
|
||
print(f" ... (共 {len(body)} 行)")
|
||
|
||
# 如果 forms_result 为空,打印原始结果
|
||
if not forms_result:
|
||
print(f" 原始结果键: {list(result.keys())}")
|
||
formatted = json.dumps(result, ensure_ascii=False, indent=2)
|
||
print(formatted[:3000])
|
||
|
||
return result
|
||
|
||
|
||
def save_results(results, output_path):
    """Write ``results`` to ``output_path`` as indented UTF-8 JSON and print the path.

    Args:
        results: JSON-serialisable object to store.
        output_path: Destination file path; the file is overwritten.
    """
    with open(output_path, mode="w", encoding="utf-8") as out_file:
        # ensure_ascii=False keeps non-ASCII text readable in the file.
        json.dump(results, out_file, indent=2, ensure_ascii=False)
    print(f"\n[SAVE] 结果已保存到: {output_path}")
|
||
|
||
|
||
def main():
    """Run every OCR test against ``IMAGE_PATH`` and save the collected results.

    Steps: verify that the three credentials are set, verify that the image
    exists, create the ``AipOcr`` client, read the image bytes, run the six
    test functions in order, print a summary and write the successful results
    to ``test_baidu_ocr_results.json`` next to this script.

    Exits with status 1 when a credential is missing or the image file does
    not exist.
    """
    rule = "=" * 70
    print(rule)
    print("百度OCR识别测试 - 产品报价表图片")
    print(rule)

    # Check the configuration.
    if not (APP_ID and API_KEY and SECRET_KEY):
        print("[FAIL] 百度OCR未配置,请检查 .env 文件中的 BAIDU_OCR_* 变量")
        sys.exit(1)

    print(f" APP_ID: {APP_ID}")
    print(f" API_KEY: {API_KEY[:8]}...")
    print(f" SECRET_KEY: {SECRET_KEY[:8]}...")

    # Check the image file.
    image_file = Path(IMAGE_PATH)
    if not image_file.exists():
        print(f"[FAIL] 图片文件不存在: {IMAGE_PATH}")
        sys.exit(1)

    size_bytes = image_file.stat().st_size
    print(f" 图片路径: {IMAGE_PATH}")
    print(f" 文件大小: {size_bytes / 1024:.1f} KB")

    # Create the Baidu OCR client.
    client = AipOcr(APP_ID, API_KEY, SECRET_KEY)

    # Read the image.
    with open(IMAGE_PATH, "rb") as image_fh:
        image_data = image_fh.read()

    print(f" 图片数据: {len(image_data)} bytes")

    # (result key, test function, method label, whether to record a line count)
    plan = (
        ("accurate_basic", test_accurate_basic, "basicAccurate(高精度版)", True),
        ("accurate_with_location", test_accurate, "accurate(高精度含位置版)", True),
        ("general_basic", test_general_basic, "basicGeneral(标准版)", True),
        ("table_recognition_async", test_table_recognize, "tableRecognitionAsync(表格识别-异步)", False),
        ("web_image", test_web_image, "webImage(网络图片文字识别)", True),
        ("table_sync", test_table_sync, "form(表格识别-同步)", False),
    )

    # Collect the result of every test that returned something truthy.
    all_results = {}
    for key, run_test, label, has_line_count in plan:
        outcome = run_test(client, image_data)
        if not outcome:
            continue
        entry = {"method": label}
        if has_line_count:
            entry["lines"] = len(outcome.get("words_result", []))
        entry["data"] = outcome
        all_results[key] = entry

    # Summary.
    print("\n" + rule)
    print("测试汇总")
    print(rule)
    for summary in all_results.values():
        line_count = summary.get("lines", "N/A")
        print(f" {summary['method']:45s} 识别行数: {line_count}")

    # Save the results.
    output_path = Path(__file__).parent / "test_baidu_ocr_results.json"
    save_results(all_results, output_path)

    print("\n[DONE] 所有测试完成!")


if __name__ == "__main__":
    main()
|