# ztb/spiders/base.py
# -*- coding: utf-8 -*-
"""
Spider base class - built on requests.
"""
import csv
import logging
import os
import random
import re
import signal
import sys
import time
from datetime import datetime
from abc import ABC, abstractmethod
from logging.handlers import RotatingFileHandler
import requests

logger = logging.getLogger("ztb")


def setup_logging(log_dir: str = "logs", level: int = logging.INFO):
    """Configure logging: rotating file plus console output."""
    os.makedirs(log_dir, exist_ok=True)
    root = logging.getLogger("ztb")
    if root.handlers:  # avoid initialising twice
        return root
    root.setLevel(level)
    fmt = logging.Formatter(
        "%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    # File log: auto-rotating, 5 MB per file, 5 backups kept
    fh = RotatingFileHandler(
        os.path.join(log_dir, "spider.log"),
        maxBytes=5 * 1024 * 1024,
        backupCount=5,
        encoding="utf-8",
    )
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(fmt)
    root.addHandler(fh)
    # Console: INFO and above only
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    ch.setFormatter(fmt)
    root.addHandler(ch)
    return root
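
# Typical initialisation (illustrative only; the directory and level shown are just the defaults):
#     setup_logging(log_dir="logs", level=logging.INFO)
#     logging.getLogger("ztb").info("logging configured")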


class BaseSpider(ABC):
    """Spider base class."""

    def __init__(self, config: dict, spider_config: dict, data_dir: str):
        self.config = config
        self.spider_config = spider_config
        self.data_dir = data_dir
        self.results = []
        self._seen_urls = set()  # URL de-duplication
        # Safety counters
        self._total_requests = 0
        self._consecutive_errors = 0
        self._stopped = False
        self._start_time = time.time()
        self._minute_requests = []  # request timestamps within the last minute
        # HTTP session
        self.session = requests.Session()
        self.session.headers.update({
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                          "AppleWebKit/537.36 (KHTML, like Gecko) "
                          "Chrome/120.0.0.0 Safari/537.36",
            "Accept": "text/html,application/xhtml+xml,application/xml;"
                      "q=0.9,image/webp,*/*;q=0.8",
            "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8",
        })
        # Register graceful-shutdown handlers
        signal.signal(signal.SIGINT, self._handle_stop)
        signal.signal(signal.SIGTERM, self._handle_stop)

    # ---------- Safety mechanisms ----------
    def _handle_stop(self, signum, frame):
        """Catch interrupt signals, save collected data, then exit."""
        logger.warning("收到中断信号,正在保存已采集数据...")
        self._stopped = True
        self.save_to_csv()
        sys.exit(0)

    def _check_limits(self) -> bool:
        """Check the safety thresholds; return True if crawling should stop."""
        max_req = self.spider_config.get("max_total_requests", 300)
        if self._total_requests >= max_req:
            logger.warning(f"达到最大请求数 ({max_req}),停止爬取")
            return True
        max_err = self.spider_config.get("max_consecutive_errors", 5)
        if self._consecutive_errors >= max_err:
            logger.error(f"连续失败 {max_err} 次,触发熔断")
            return True
        return self._stopped

    # ---------- HTTP requests ----------
    def _throttle(self):
        """Enforce the per-minute request limit; wait if it is exceeded."""
        rpm_limit = self.spider_config.get("requests_per_minute", 10)
        now = time.time()
        # Drop timestamps older than 60 s
        self._minute_requests = [t for t in self._minute_requests if now - t < 60]
        if len(self._minute_requests) >= rpm_limit:
            wait = 60 - (now - self._minute_requests[0]) + random.uniform(1, 3)
            if wait > 0:
                logger.info(f"达到速率限制 ({rpm_limit}次/分钟),等待 {wait:.0f}s...")
                time.sleep(wait)
        self._minute_requests.append(time.time())

    def fetch(self, url: str, method: str = "GET", **kwargs) -> requests.Response | None:
        """HTTP request with retries, rate limiting and safety checks."""
        if self._check_limits():
            return None
        self._throttle()
        timeout = kwargs.pop("timeout", self.spider_config.get("timeout", 30))
        max_retries = self.spider_config.get("max_retries", 3)
        for attempt in range(1, max_retries + 1):
            try:
                self._total_requests += 1
                resp = self.session.request(method, url, timeout=timeout, **kwargs)
                resp.raise_for_status()
                # Detect blocked responses: anti-bot systems may return 200 with an empty body
                if len(resp.content) <= 10 and "json" not in resp.headers.get("Content-Type", ""):
                    self._consecutive_errors += 1
                    logger.warning(f"检测到空响应 ({len(resp.content)} bytes),可能被反爬")
                    if attempt < max_retries:
                        wait = 10 * attempt + random.uniform(5, 10)
                        logger.info(f"疑似被反爬拦截,等待 {wait:.0f}s 后重试...")
                        time.sleep(wait)
                        continue
                    return None
                self._consecutive_errors = 0
                return resp
            except requests.RequestException as e:
                self._consecutive_errors += 1
                wait = 2 ** attempt + random.random()
                logger.warning(f"请求失败 ({attempt}/{max_retries}): {e},{wait:.1f}s 后重试")
                if attempt < max_retries:
                    time.sleep(wait)
        logger.error(f"请求失败,已达最大重试次数: {url[:80]}")
        return None

    def delay(self):
        """Random delay between list pages."""
        lo = self.spider_config.get("delay_min", 3)
        hi = self.spider_config.get("delay_max", 6)
        time.sleep(random.uniform(lo, hi))

    def detail_delay(self):
        """Random delay before each detail-page request."""
        lo = self.spider_config.get("detail_delay_min", 2)
        hi = self.spider_config.get("detail_delay_max", 5)
        time.sleep(random.uniform(lo, hi))

    def print_stats(self):
        """Log crawl statistics."""
        elapsed = time.time() - self._start_time
        rpm = self._total_requests / max(elapsed / 60, 0.1)
        logger.info(f"[统计] 总请求: {self._total_requests}, "
                    f"耗时: {elapsed:.0f}s, 速率: {rpm:.1f}次/分钟")

    # ---------- Title parsing (shared rules) ----------
    @staticmethod
    def _parse_title(title: str) -> dict:
        """Extract the project name and approval number from a title (shared rules)."""
        result = {}
        # One shared pattern: optional prefix, greedy project name, trailing approval number
        title_pattern = r"(?:\[(?:招标文件|招标公告)\])?\s*(.*)\s*\[([A-Z0-9]+)\]\s*$"
        match = re.search(title_pattern, title)
        if match:
            project_name = match.group(1).strip()
            result["项目批准文号"] = match.group(2).strip()
        else:
            project_name = title
            # Fall back to extracting an approval number from the end of the title
            number_pattern = r"\[([A-Z0-9]+)\]\s*$"
            match = re.search(number_pattern, project_name)
            if match:
                result["项目批准文号"] = match.group(1).strip()
                project_name = project_name[:match.start()].strip()
        # Strip announcement-type suffixes from the project name
        suffixes = ["招标文件公示", "招标文件预公示", "招标公告", "招标预公告"]
        for suffix in suffixes:
            if project_name.endswith(suffix):
                project_name = project_name[:-len(suffix)].strip()
        result["项目名称"] = project_name
        return result
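
    # Example of the rule on a hypothetical title (illustrative, not from the project's data):
    #     _parse_title("[招标公告]某某工程施工招标公告[AB2024001]")
    #     -> {"项目批准文号": "AB2024001", "项目名称": "某某工程施工"}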

    # ---------- De-duplication ----------
    def is_duplicate(self, url: str) -> bool:
        """URL-based de-duplication."""
        if url in self._seen_urls:
            return True
        self._seen_urls.add(url)
        return False

    # ---------- Data storage ----------
    def save_to_csv(self, filename: str | None = None):
        """Save the collected results to a CSV file."""
        if not self.results:
            logger.info("没有数据可保存")
            return
        if not filename:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"{self.config['name']}_{timestamp}.csv"
        filepath = os.path.join(self.data_dir, filename)
        os.makedirs(self.data_dir, exist_ok=True)
        # Collect the union of all field names, preserving first-seen order
        all_keys = []
        seen = set()
        for row in self.results:
            for k in row:
                if k not in seen:
                    all_keys.append(k)
                    seen.add(k)
        with open(filepath, "w", newline="", encoding="utf-8-sig") as f:
            writer = csv.DictWriter(f, fieldnames=all_keys, extrasaction="ignore")
            writer.writeheader()
            writer.writerows(self.results)
        logger.info(f"数据已保存到: {filepath} (共 {len(self.results)} 条记录)")

    # ---------- Abstract methods ----------
    @abstractmethod
    def crawl(self, max_pages: int | None = None, **kwargs):
        """Run the crawl; implemented by subclasses."""
        pass
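

# ---------- Illustrative usage (sketch, not part of the original module) ----------
# A minimal concrete spider showing how the base-class helpers fit together.
# The class name, the "list_url" config key, the link regex and the "链接" field
# are assumptions made for this example only.
class _ExampleSpider(BaseSpider):
    """Sketch of a subclass driving fetch(), de-duplication, title parsing and CSV export."""

    def _parse_list(self, html: str):
        # Hypothetical list-page parser: yield (url, title) pairs from anchor tags.
        for m in re.finditer(r'<a href="([^"]+)"[^>]*>([^<]+)</a>', html):
            yield m.group(1), m.group(2)

    def crawl(self, max_pages: int | None = None, **kwargs):
        list_url = self.config.get("list_url", "")  # assumed config key
        for page in range(1, (max_pages or 1) + 1):
            resp = self.fetch(f"{list_url}?page={page}")
            if resp is None:
                break
            for url, title in self._parse_list(resp.text):
                if self.is_duplicate(url):
                    continue
                row = {"链接": url}
                row.update(self._parse_title(title))
                self.results.append(row)
            self.delay()
        self.print_stats()
        self.save_to_csv()


# Example invocation (paths and config values are assumptions):
#     setup_logging()
#     spider = _ExampleSpider(
#         config={"name": "example", "list_url": "https://example.com/notices"},
#         spider_config={},
#         data_dir="data",
#     )
#     spider.crawl(max_pages=1)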