# NOTE(review): the lines below are residue from the web page this file was
# copied from (file-listing chrome, not code); kept as comments so the
# module remains valid Python.
# Files
# ztb/spiders/base.py
# 230 lines, 7.8 KiB, Python
# Raw Normal View History

# -*- coding: utf-8 -*-
"""
爬虫基类 - 基于 requests
"""
import csv
import logging
import os
import random
import signal
import sys
import time
from datetime import datetime
from abc import ABC, abstractmethod
from logging.handlers import RotatingFileHandler
import requests
logger = logging.getLogger("ztb")
def setup_logging(log_dir: str = "logs", level: int = logging.INFO):
    """Configure the "ztb" logger: rotating file log plus console output.

    Idempotent — if the logger already has handlers attached, it is
    returned unchanged so repeated calls don't duplicate output.

    :param log_dir: directory for the log file (created if missing)
    :param level: level set on the logger itself
    :return: the configured ``logging.Logger`` named ``"ztb"``
    """
    os.makedirs(log_dir, exist_ok=True)
    ztb_logger = logging.getLogger("ztb")
    if ztb_logger.handlers:  # already initialised — avoid duplicate handlers
        return ztb_logger
    ztb_logger.setLevel(level)
    formatter = logging.Formatter(
        "%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    # File log: auto-rotating, 5 MB per file, 5 backups kept, captures DEBUG+.
    file_handler = RotatingFileHandler(
        os.path.join(log_dir, "spider.log"),
        maxBytes=5 * 1024 * 1024,
        backupCount=5,
        encoding="utf-8",
    )
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(formatter)
    # Console: INFO and above only.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    for handler in (file_handler, console_handler):
        ztb_logger.addHandler(handler)
    return ztb_logger
class BaseSpider(ABC):
    """Abstract base class for requests-based spiders.

    Provides:
      * throttled HTTP fetching with retries, backoff and a circuit breaker
      * URL-based de-duplication
      * graceful shutdown on SIGINT/SIGTERM (saves collected data first)
      * CSV export of the collected result dicts

    Subclasses implement :meth:`crawl`.
    """

    def __init__(self, config: dict, spider_config: dict, data_dir: str):
        """
        :param config: site config; must contain ``"name"`` (CSV filename prefix)
        :param spider_config: tuning knobs (rate limits, retries, delays, timeouts)
        :param data_dir: directory where CSV output is written
        """
        self.config = config
        self.spider_config = spider_config
        self.data_dir = data_dir
        self.results = []            # collected item dicts
        self._seen_urls = set()      # de-duplication by URL
        # Safety counters.
        self._total_requests = 0
        self._consecutive_errors = 0
        self._stopped = False
        self._start_time = time.time()
        self._minute_requests = []   # timestamps of requests in the last minute
        # HTTP session with browser-like headers.
        self.session = requests.Session()
        self.session.headers.update({
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                          "AppleWebKit/537.36 (KHTML, like Gecko) "
                          "Chrome/120.0.0.0 Safari/537.36",
            "Accept": "text/html,application/xhtml+xml,application/xml;"
                      "q=0.9,image/webp,*/*;q=0.8",
            "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8",
        })
        # Register graceful shutdown.
        # NOTE(review): signal.signal raises ValueError outside the main
        # thread — confirm spiders are only constructed in the main thread.
        signal.signal(signal.SIGINT, self._handle_stop)
        signal.signal(signal.SIGTERM, self._handle_stop)

    # ---------- safety mechanisms ----------
    def _handle_stop(self, signum, frame):
        """Signal handler: persist collected data, then exit the process."""
        logger.warning("收到中断信号,正在保存已采集数据...")
        self._stopped = True
        self.save_to_csv()
        sys.exit(0)

    def _check_limits(self) -> bool:
        """Return True when crawling should stop.

        Stop conditions: total-request cap reached, consecutive-error
        circuit breaker tripped, or an interrupt signal was received.
        """
        max_req = self.spider_config.get("max_total_requests", 300)
        if self._total_requests >= max_req:
            logger.warning(f"达到最大请求数 ({max_req}),停止爬取")
            return True
        max_err = self.spider_config.get("max_consecutive_errors", 5)
        if self._consecutive_errors >= max_err:
            logger.error(f"连续失败 {max_err} 次,触发熔断")
            return True
        return self._stopped

    # ---------- networking ----------
    def _throttle(self):
        """Enforce the requests-per-minute limit, sleeping when exceeded."""
        rpm_limit = self.spider_config.get("requests_per_minute", 10)
        now = time.time()
        # Drop timestamps older than 60 seconds.
        self._minute_requests = [t for t in self._minute_requests if now - t < 60]
        if len(self._minute_requests) >= rpm_limit:
            # Wait until the oldest request falls out of the window, plus jitter.
            wait = 60 - (now - self._minute_requests[0]) + random.uniform(1, 3)
            if wait > 0:
                logger.info(f"达到速率限制 ({rpm_limit}次/分钟),等待 {wait:.0f}s...")
                time.sleep(wait)
        self._minute_requests.append(time.time())

    def fetch(self, url: str, method: str = "GET", **kwargs) -> requests.Response | None:
        """HTTP request with rate limiting, retries and safety checks.

        :param url: target URL
        :param method: HTTP method, default ``"GET"``
        :param kwargs: passed through to ``Session.request``; ``timeout``
            defaults to ``spider_config["timeout"]`` (30s)
        :return: the response on success; ``None`` when limits were hit,
            the response looked like an anti-bot block, or retries ran out
        """
        if self._check_limits():
            return None
        self._throttle()
        timeout = kwargs.pop("timeout", self.spider_config.get("timeout", 30))
        max_retries = self.spider_config.get("max_retries", 3)
        for attempt in range(1, max_retries + 1):
            try:
                self._total_requests += 1
                resp = self.session.request(method, url, timeout=timeout, **kwargs)
                resp.raise_for_status()
                # A 200 with a near-empty non-JSON body usually means the
                # anti-bot layer served us an empty page.
                if len(resp.content) <= 10 and "json" not in resp.headers.get("Content-Type", ""):
                    self._consecutive_errors += 1
                    logger.warning(f"检测到空响应 ({len(resp.content)} bytes),可能被反爬")
                    if attempt < max_retries:
                        wait = 10 * attempt + random.uniform(5, 10)
                        logger.info(f"疑似被反爬拦截,等待 {wait:.0f}s 后重试...")
                        time.sleep(wait)
                        continue
                    return None
                self._consecutive_errors = 0
                return resp
            except requests.RequestException as e:
                self._consecutive_errors += 1
                wait = 2 ** attempt + random.random()
                # FIX: separator between the error text and the retry delay
                # was missing, fusing them in the log output.
                logger.warning(f"请求失败 ({attempt}/{max_retries}): {e},{wait:.1f}s 后重试")
                if attempt < max_retries:
                    time.sleep(wait)
        logger.error(f"请求失败,已达最大重试次数: {url[:80]}")
        return None

    def delay(self):
        """Random delay between list pages."""
        lo = self.spider_config.get("delay_min", 3)
        hi = self.spider_config.get("delay_max", 6)
        time.sleep(random.uniform(lo, hi))

    def detail_delay(self):
        """Random delay before a detail-page request."""
        lo = self.spider_config.get("detail_delay_min", 2)
        hi = self.spider_config.get("detail_delay_max", 5)
        time.sleep(random.uniform(lo, hi))

    def print_stats(self):
        """Log crawl totals: request count, elapsed time and request rate."""
        elapsed = time.time() - self._start_time
        # Clamp the denominator so very short runs don't divide by ~zero.
        rpm = self._total_requests / max(elapsed / 60, 0.1)
        logger.info(f"[统计] 总请求: {self._total_requests}, "
                    f"耗时: {elapsed:.0f}s, 速率: {rpm:.1f}次/分钟")

    # ---------- de-duplication ----------
    def is_duplicate(self, url: str) -> bool:
        """Return True if *url* was seen before; otherwise record it and return False."""
        if url in self._seen_urls:
            return True
        self._seen_urls.add(url)
        return False

    # ---------- persistence ----------
    def save_to_csv(self, filename: str | None = None):
        """Write collected results to a CSV file under ``self.data_dir``.

        :param filename: output filename; defaults to
            ``<config["name"]>_<timestamp>.csv``
        """
        if not self.results:
            logger.info("没有数据可保存")
            return
        if not filename:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"{self.config['name']}_{timestamp}.csv"
        filepath = os.path.join(self.data_dir, filename)
        os.makedirs(self.data_dir, exist_ok=True)
        # Union of all row keys, in first-seen order (rows may differ in fields).
        field_order: dict = {}
        for row in self.results:
            field_order.update(dict.fromkeys(row))
        # utf-8-sig so Excel detects the encoding and renders CJK correctly.
        with open(filepath, "w", newline="", encoding="utf-8-sig") as f:
            writer = csv.DictWriter(f, fieldnames=list(field_order), extrasaction="ignore")
            writer.writeheader()
            writer.writerows(self.results)
        logger.info(f"数据已保存到: {filepath} (共 {len(self.results)} 条记录)")

    # ---------- abstract API ----------
    @abstractmethod
    def crawl(self, max_pages: int | None = None, **kwargs):
        """Run the crawl; implemented by subclasses."""
        pass