const mysql = require('mysql2/promise'); let pool = null; async function ensureColumnExists(tableName, columnName, definitionSql) { const [rows] = await pool.query(`SHOW COLUMNS FROM \`${tableName}\` LIKE ?`, [columnName]); if (rows.length === 0) { await pool.execute(`ALTER TABLE \`${tableName}\` ADD COLUMN ${definitionSql}`); } } async function migrateSchema() { await pool.execute("ALTER TABLE `sessions` MODIFY COLUMN `mode` ENUM('voice', 'chat') DEFAULT 'chat'"); await pool.execute("ALTER TABLE `messages` MODIFY COLUMN `role` ENUM('user', 'assistant', 'tool', 'system') NOT NULL"); await pool.execute("ALTER TABLE `messages` MODIFY COLUMN `source` ENUM('voice_asr', 'voice_bot', 'voice_tool', 'chat_user', 'chat_bot') NOT NULL"); await ensureColumnExists('messages', 'tool_name', '`tool_name` VARCHAR(64) NULL AFTER `source`'); await ensureColumnExists('messages', 'meta_json', '`meta_json` JSON NULL AFTER `tool_name`'); await ensureColumnExists('messages', 'created_at', '`created_at` BIGINT NULL AFTER `tool_name`'); await ensureColumnExists('sessions', 'updated_at', '`updated_at` BIGINT NULL AFTER `created_at`'); } /** * 初始化 MySQL 连接池 + 自动建表 */ async function initialize() { // 先连接不指定数据库,确保数据库存在 const tempConn = await mysql.createConnection({ host: process.env.MYSQL_HOST || 'localhost', port: parseInt(process.env.MYSQL_PORT || '3306'), user: process.env.MYSQL_USER || 'root', password: process.env.MYSQL_PASSWORD || '', }); const dbName = process.env.MYSQL_DATABASE || 'bigwo_chat'; await tempConn.execute(`CREATE DATABASE IF NOT EXISTS \`${dbName}\` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci`); await tempConn.end(); // 创建连接池 pool = mysql.createPool({ host: process.env.MYSQL_HOST || 'localhost', port: parseInt(process.env.MYSQL_PORT || '3306'), user: process.env.MYSQL_USER || 'root', password: process.env.MYSQL_PASSWORD || '', database: dbName, waitForConnections: true, connectionLimit: 10, charset: 'utf8mb4', }); // 建表 await pool.execute(` CREATE TABLE IF NOT EXISTS sessions ( id VARCHAR(128) PRIMARY KEY, user_id VARCHAR(128), mode ENUM('voice', 'chat') DEFAULT 'chat', created_at BIGINT, updated_at BIGINT, INDEX idx_user (user_id), INDEX idx_updated (updated_at) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 `); await pool.execute(` CREATE TABLE IF NOT EXISTS messages ( id INT AUTO_INCREMENT PRIMARY KEY, session_id VARCHAR(128) NOT NULL, role ENUM('user', 'assistant', 'tool', 'system') NOT NULL, content TEXT NOT NULL, source ENUM('voice_asr', 'voice_bot', 'voice_tool', 'chat_user', 'chat_bot') NOT NULL, tool_name VARCHAR(64), meta_json JSON, created_at BIGINT, INDEX idx_session (session_id), INDEX idx_session_time (session_id, created_at) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 `); await migrateSchema(); console.log(`[DB] MySQL connected: ${dbName}, tables ready`); return pool; } /** * 获取连接池 */ function getPool() { if (!pool) throw new Error('[DB] Not initialized. Call initialize() first.'); return pool; } // ==================== Sessions ==================== async function createSession(sessionId, userId, mode = 'chat') { const now = Date.now(); await pool.execute( 'INSERT INTO sessions (id, user_id, mode, created_at, updated_at) VALUES (?, ?, ?, ?, ?) ON DUPLICATE KEY UPDATE mode=VALUES(mode), updated_at=VALUES(updated_at)', [sessionId, userId || null, mode, now, now] ); return { id: sessionId, userId, mode, created_at: now }; } async function updateSessionMode(sessionId, mode) { await pool.execute( 'UPDATE sessions SET mode=?, updated_at=? WHERE id=?', [mode, Date.now(), sessionId] ); } async function getSession(sessionId) { const [rows] = await pool.execute('SELECT * FROM sessions WHERE id=?', [sessionId]); return rows[0] || null; } // ==================== Messages ==================== async function addMessage(sessionId, role, content, source, toolName = null, meta = null) { if (!content || content.trim() === '') return null; const now = Date.now(); const metaJson = meta == null ? null : JSON.stringify(meta); const [result] = await pool.execute( 'INSERT INTO messages (session_id, role, content, source, tool_name, meta_json, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)', [sessionId, role, content, source, toolName, metaJson, now] ); // 更新 session 时间 await pool.execute('UPDATE sessions SET updated_at=? WHERE id=?', [now, sessionId]); return { id: result.insertId, session_id: sessionId, role, content, source, tool_name: toolName, meta_json: metaJson, created_at: now }; } async function getMessages(sessionId, limit = 20) { const safeLimit = Math.max(1, Math.min(parseInt(limit) || 20, 100)); const [rows] = await pool.query( 'SELECT role, content, source, tool_name, meta_json, created_at FROM messages WHERE session_id=? ORDER BY created_at ASC LIMIT ?', [sessionId, safeLimit] ); return rows; } async function getRecentMessages(sessionId, limit = 20) { // 获取最近 N 条,按时间正序返回 const safeLimit = Math.max(1, Math.min(parseInt(limit) || 20, 100)); const [rows] = await pool.query( `SELECT role, content, source, tool_name, meta_json, created_at FROM messages WHERE session_id=? ORDER BY created_at DESC LIMIT ?`, [sessionId, safeLimit] ); return rows.reverse(); } /** * 获取会话历史(格式化为 LLM 可用的 {role, content} 数组) * 合并 tool 消息到 assistant 消息 */ async function getHistoryForLLM(sessionId, limit = 20) { const messages = await getRecentMessages(sessionId, limit); return messages .filter(m => m.role === 'user' || m.role === 'assistant') .map(m => ({ role: m.role, content: m.content })); } /** * 获取会话列表(按更新时间倒序,带最后一条消息预览) */ async function getSessionList(userId, limit = 50) { const safeLimit = Math.max(1, Math.min(parseInt(limit) || 50, 200)); let query; let params; if (userId) { query = `SELECT s.id, s.user_id, s.mode, s.created_at, s.updated_at, (SELECT content FROM messages WHERE session_id = s.id ORDER BY created_at DESC LIMIT 1) AS last_message, (SELECT COUNT(*) FROM messages WHERE session_id = s.id) AS message_count FROM sessions s WHERE s.user_id = ? ORDER BY s.updated_at DESC LIMIT ?`; params = [userId, safeLimit]; } else { query = `SELECT s.id, s.user_id, s.mode, s.created_at, s.updated_at, (SELECT content FROM messages WHERE session_id = s.id ORDER BY created_at DESC LIMIT 1) AS last_message, (SELECT COUNT(*) FROM messages WHERE session_id = s.id) AS message_count FROM sessions s ORDER BY s.updated_at DESC LIMIT ?`; params = [safeLimit]; } const [rows] = await pool.query(query, params); return rows; } /** * 删除会话及其所有消息 */ async function deleteSession(sessionId) { await pool.execute('DELETE FROM messages WHERE session_id = ?', [sessionId]); await pool.execute('DELETE FROM sessions WHERE id = ?', [sessionId]); } module.exports = { initialize, getPool, createSession, updateSessionMode, getSession, addMessage, getMessages, getRecentMessages, getHistoryForLLM, getSessionList, deleteSession, };