# -*- coding: utf-8 -*-
"""
Baidu OCR recognition test script.

Goal: run several Baidu OCR endpoints against a product quotation sheet image,
print what each endpoint recognises and save the raw responses to a JSON file
so the recognition quality can be compared.

Usage:
    python <this script> [image_path]

When ``image_path`` is omitted, ``IMAGE_PATH`` below is used.
"""
import io
import json
import os
import sys
import time
from pathlib import Path

from aip import AipOcr
from dotenv import load_dotenv


def _utf8_stream(stream):
    """Return *stream* configured to emit UTF-8, replacing unencodable characters.

    ``reconfigure`` changes the encoding in place and leaves the stream's other
    settings (notably line buffering) alone, so progress lines still appear
    immediately on a terminal.  Streams without ``reconfigure`` are wrapped
    instead; streams that expose neither API are returned untouched rather
    than crashing the script.
    """
    if hasattr(stream, "reconfigure"):
        stream.reconfigure(encoding="utf-8", errors="replace")
        return stream
    if hasattr(stream, "buffer"):
        return io.TextIOWrapper(
            stream.buffer,
            encoding="utf-8",
            errors="replace",
            line_buffering=True,
        )
    return stream


# Make UTF-8 output work on Windows terminals.
sys.stdout = _utf8_stream(sys.stdout)
sys.stderr = _utf8_stream(sys.stderr)

# Load environment variables from the .env file next to this script.
load_dotenv(Path(__file__).parent / ".env")

# ==================== Configuration ====================
APP_ID = os.getenv("BAIDU_OCR_APP_ID", "")
API_KEY = os.getenv("BAIDU_OCR_API_KEY", "")
SECRET_KEY = os.getenv("BAIDU_OCR_SECRET_KEY", "")

# Default test image; can be overridden by the first command-line argument.
IMAGE_PATH = r"C:\Users\UI\.cursor\projects\c-Users-UI-Desktop\assets\c__Users_UI_AppData_Roaming_Cursor_User_workspaceStorage_6df83b93d4a0651428307542725e79d8_images_ecdbe509-3f63-49c0-a8be-db9facaef857_3_-4dec6c0d-a755-4bda-8780-9e6b20e02df8.png"

# Async table recognition polling: 20 attempts x 3 seconds = at most 60 seconds.
_TABLE_POLL_ATTEMPTS = 20
_TABLE_POLL_INTERVAL = 3

# Width of the separator rules printed around each section.
_RULE_WIDTH = 70


def _print_banner(title):
    """Print *title* framed by two ``=`` rules, preceded by a blank line."""
    print("\n" + "=" * _RULE_WIDTH)
    print(title)
    print("=" * _RULE_WIDTH)


def _report_error(result, label="错误", default_msg="未知错误"):
    """Print a ``[FAIL]`` line if *result* carries ``error_code``.

    Returns True when an error was reported, False otherwise.
    """
    if "error_code" not in result:
        return False
    print(f" [FAIL] {label} ({result['error_code']}): {result.get('error_msg', default_msg)}")
    return True


def _run_words_test(title, recognize, image_data, show_location=False):
    """Call one ``words_result``-style OCR endpoint and print every line.

    Args:
        title: Banner text for this test.
        recognize: Bound client method taking the image bytes.
        image_data: Raw image bytes.
        show_location: Also print each item's ``location`` box when True.

    Returns:
        The response dict, or None when the response contains ``error_code``.
    """
    _print_banner(title)

    start = time.perf_counter()
    result = recognize(image_data)
    elapsed = time.perf_counter() - start

    if _report_error(result):
        return None

    words_result = result.get("words_result", [])
    print(f" [OK] 识别成功 | 耗时: {elapsed:.2f}s | 识别行数: {len(words_result)}")
    print("-" * _RULE_WIDTH)
    for index, item in enumerate(words_result, start=1):
        # .get: an item without "words" should not abort the whole listing.
        words = item.get("words", "")
        if show_location:
            loc = item.get("location", {})
            pos_str = (
                f"(x={loc.get('left', 0)}, y={loc.get('top', 0)}, "
                f"w={loc.get('width', 0)}, h={loc.get('height', 0)})"
            )
            print(f" [{index:3d}] {pos_str:40s} {words}")
        else:
            print(f" [{index:3d}] {words}")
    return result


def test_accurate_basic(client, image_data):
    """Test 1: general text recognition (high accuracy) - basicAccurate."""
    return _run_words_test(
        "[测试1] 通用文字识别(高精度版)- basicAccurate",
        client.basicAccurate,
        image_data,
    )


def test_accurate(client, image_data):
    """Test 2: general text recognition (high accuracy, with positions) - accurate."""
    return _run_words_test(
        "[测试2] 通用文字识别(高精度含位置版)- accurate",
        client.accurate,
        image_data,
        show_location=True,
    )


def test_general_basic(client, image_data):
    """Test 3: general text recognition (standard) - basicGeneral."""
    return _run_words_test(
        "[测试3] 通用文字识别(标准版)- basicGeneral",
        client.basicGeneral,
        image_data,
    )


def _extract_request_id(result):
    """Return the ``request_id`` from a submit response, or ``""`` if absent.

    The ``result`` field is handled both as a list of dicts and as a single
    dict, because the response shape may differ.
    """
    payload = result.get("result", [])
    if isinstance(payload, list) and payload:
        first = payload[0]
        return first.get("request_id", "") if isinstance(first, dict) else ""
    if isinstance(payload, dict):
        return payload.get("request_id", "")
    return ""


def _print_table_payload(result_data, limit=5000):
    """Print the table recognition payload, truncated to *limit* characters.

    A string payload is parsed as JSON and pretty-printed; if it is not valid
    JSON it is printed raw.  A payload that is already a decoded structure is
    pretty-printed directly.
    """
    print("-" * _RULE_WIDTH)
    print(" 表格识别结果(原始):")
    table_data = result_data
    if isinstance(result_data, str):
        try:
            table_data = json.loads(result_data)
        except ValueError:  # json.JSONDecodeError is a ValueError
            print(result_data[:limit])
            return
    formatted = json.dumps(table_data, ensure_ascii=False, indent=2)
    print(formatted[:limit])
    if len(formatted) > limit:
        print(" ... (结果过长,已截断)")


def test_table_recognize(client, image_data):
    """Test 4: table recognition - tableRecognitionAsync (asynchronous).

    Submits the image, then polls ``getTableRecognitionResult`` every
    ``_TABLE_POLL_INTERVAL`` seconds for at most ``_TABLE_POLL_ATTEMPTS`` tries.

    Returns:
        The final poll response dict once ``ret_code`` is 3, or None on a
        submit/query error, a missing ``request_id`` or a timeout.
    """
    _print_banner("[测试4] 表格文字识别 - tableRecognitionAsync")

    # Submit the table recognition request.
    start = time.perf_counter()
    result = client.tableRecognitionAsync(image_data)
    if _report_error(result, "提交失败"):
        return None

    request_id = _extract_request_id(result)
    if not request_id:
        print(f" [FAIL] 未获取到 request_id,返回结果: {json.dumps(result, ensure_ascii=False)}")
        return None

    print(f" [INFO] 提交成功 | request_id: {request_id}")
    print(" [INFO] 等待识别结果...")

    ret_code = -1
    for attempt in range(1, _TABLE_POLL_ATTEMPTS + 1):
        time.sleep(_TABLE_POLL_INTERVAL)
        get_result = client.getTableRecognitionResult(request_id)
        if _report_error(get_result, "查询失败", default_msg=""):
            return None

        # Treat a missing or non-dict "result" as "no status yet" instead of
        # crashing on .get().
        status = get_result.get("result")
        if not isinstance(status, dict):
            status = {}
        percent = status.get("percent", 0)
        ret_code = status.get("ret_code", -1)

        if ret_code == 3:  # recognition finished
            elapsed = time.perf_counter() - start
            print(f" [OK] 识别完成 | 耗时: {elapsed:.2f}s")
            result_data = status.get("result_data", "")
            if result_data:
                _print_table_payload(result_data)
            return get_result

        print(f" 轮询 {attempt}/{_TABLE_POLL_ATTEMPTS} | 进度: {percent}%")

    elapsed = time.perf_counter() - start
    print(f" [WARN] 超时(等待 {elapsed:.1f}s),最后状态: ret_code={ret_code}")
    return None


def test_web_image(client, image_data):
    """Test 5: web image text recognition - webImage."""
    return _run_words_test(
        "[测试5] 网络图片文字识别 - webImage",
        client.webImage,
        image_data,
    )


def _format_form_row(row):
    """Render one form row as text, or return None for unsupported row types.

    A dict row yields its ``words`` value (or the dict itself when absent); a
    list row yields its cells joined by ``" | "``.
    """
    if isinstance(row, dict):
        return str(row.get("words", row))
    if isinstance(row, list):
        # str(): join() needs strings even if a cell's "words" is not one.
        return " | ".join(
            str(cell.get("words", cell)) if isinstance(cell, dict) else str(cell)
            for cell in row
        )
    return None


def _print_form_section(label, rows, max_rows=None):
    """Print a labelled form section, optionally showing only *max_rows* rows."""
    if not rows:
        return
    print(label)
    shown = rows if max_rows is None else rows[:max_rows]
    for row in shown:
        text = _format_form_row(row)
        if text is not None:
            print(f" {text}")
    if max_rows is not None and len(rows) > max_rows:
        print(f" ... (共 {len(rows)} 行)")


def test_table_sync(client, image_data):
    """Test 6: table recognition (synchronous) - form.

    Prints the header, body (first 80 rows) and footer of every entry in
    ``forms_result``; when that list is empty the raw response is dumped
    (truncated to 3000 characters) to help diagnose the response shape.

    Returns:
        The response dict, or None when the response contains ``error_code``.
    """
    _print_banner("[测试6] 表格识别(同步版)- form")

    start = time.perf_counter()
    result = client.form(image_data)
    elapsed = time.perf_counter() - start

    if _report_error(result):
        return None

    forms_result = result.get("forms_result", [])
    print(f" [OK] 识别成功 | 耗时: {elapsed:.2f}s | 表单数: {len(forms_result)}")
    print("-" * _RULE_WIDTH)

    for f_idx, form in enumerate(forms_result):
        print(f"\n === 表单 {f_idx + 1} ===")
        _print_form_section(" [表头]", form.get("header", []))
        _print_form_section(" [表体]", form.get("body", []), max_rows=80)
        # The footer was previously fetched but never shown.
        _print_form_section(" [表尾]", form.get("footer", []))

    if not forms_result:
        print(f" 原始结果键: {list(result.keys())}")
        formatted = json.dumps(result, ensure_ascii=False, indent=2)
        print(formatted[:3000])

    return result


def save_results(results, output_path):
    """Write *results* to *output_path* as indented UTF-8 JSON and report it."""
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    print(f"\n[SAVE] 结果已保存到: {output_path}")


def main():
    """Run every OCR test against the image, print a summary and save results.

    Exits with status 1 when the Baidu OCR credentials are not configured or
    the image file does not exist.
    """
    print("=" * _RULE_WIDTH)
    print("百度OCR识别测试 - 产品报价表图片")
    print("=" * _RULE_WIDTH)

    # Check configuration.
    if not all([APP_ID, API_KEY, SECRET_KEY]):
        print("[FAIL] 百度OCR未配置,请检查 .env 文件中的 BAIDU_OCR_* 变量")
        sys.exit(1)

    print(f" APP_ID: {APP_ID}")
    print(f" API_KEY: {API_KEY[:8]}...")
    print(f" SECRET_KEY: {SECRET_KEY[:8]}...")

    # The first CLI argument overrides the default image so the script is
    # usable on machines other than the one IMAGE_PATH was recorded on.
    image_path = Path(sys.argv[1] if len(sys.argv) > 1 else IMAGE_PATH)

    # is_file(): a directory at that path must be rejected here, not at read.
    if not image_path.is_file():
        print(f"[FAIL] 图片文件不存在: {image_path}")
        sys.exit(1)

    file_size = image_path.stat().st_size
    print(f" 图片路径: {image_path}")
    print(f" 文件大小: {file_size / 1024:.1f} KB")

    # Initialise the Baidu OCR client.
    client = AipOcr(APP_ID, API_KEY, SECRET_KEY)

    image_data = image_path.read_bytes()
    print(f" 图片数据: {len(image_data)} bytes")

    # (result key, method label, test function, response has "words_result")
    suites = [
        ("accurate_basic", "basicAccurate(高精度版)", test_accurate_basic, True),
        ("accurate_with_location", "accurate(高精度含位置版)", test_accurate, True),
        ("general_basic", "basicGeneral(标准版)", test_general_basic, True),
        ("table_recognition_async", "tableRecognitionAsync(表格识别-异步)", test_table_recognize, False),
        ("web_image", "webImage(网络图片文字识别)", test_web_image, True),
        ("table_sync", "form(表格识别-同步)", test_table_sync, False),
    ]

    all_results = {}
    for key, method, run_test, has_lines in suites:
        try:
            outcome = run_test(client, image_data)
        except Exception as exc:
            # Top-level boundary: one failing endpoint must not discard the
            # results already collected from the others.
            print(f" [FAIL] {method} 执行异常: {exc!r}")
            continue
        if outcome is None:
            continue
        entry = {"method": method}
        if has_lines:
            entry["lines"] = len(outcome.get("words_result", []))
        entry["data"] = outcome
        all_results[key] = entry

    # Summary.
    _print_banner("测试汇总")
    for val in all_results.values():
        lines = val.get("lines", "N/A")
        print(f" {val['method']:45s} 识别行数: {lines}")

    # Save results next to this script.
    output_path = Path(__file__).parent / "test_baidu_ocr_results.json"
    save_results(all_results, output_path)

    print("\n[DONE] 所有测试完成!")


if __name__ == "__main__":
    main()