221 lines
8.0 KiB
JavaScript
221 lines
8.0 KiB
JavaScript
const mysql = require('mysql2/promise');
|
|
|
|
// Shared mysql2 connection pool; null until initialize() assigns it.
// The query helpers in this module read this binding directly.
let pool = null;
|
|
|
|
/**
 * Add a column to a table if a column with exactly that name is not present.
 *
 * `SHOW COLUMNS ... LIKE` treats `_` and `%` as wildcards, so the name is escaped
 * before matching; otherwise e.g. `created_at` would also match a column named
 * `createdXat` and the ALTER would be wrongly skipped.
 *
 * @param {string} tableName - Table to inspect/alter (trusted; interpolated as an identifier).
 * @param {string} columnName - Exact column name to look for.
 * @param {string} definitionSql - Column definition appended after `ADD COLUMN` (trusted SQL).
 * @returns {Promise<void>}
 */
async function ensureColumnExists(tableName, columnName, definitionSql) {
  // Escape LIKE metacharacters (and the escape char itself) so the match is literal.
  const likePattern = columnName.replace(/[\\%_]/g, '\\$&');
  const [rows] = await pool.query(`SHOW COLUMNS FROM \`${tableName}\` LIKE ?`, [likePattern]);
  if (rows.length === 0) {
    await pool.execute(`ALTER TABLE \`${tableName}\` ADD COLUMN ${definitionSql}`);
  }
}
|
|
|
|
/**
 * Report whether a column's declared type contains a given fragment.
 *
 * The column is looked up in INFORMATION_SCHEMA for the database named by
 * MYSQL_DATABASE (default 'bigwo_chat'), and COLUMN_TYPE is searched
 * case-insensitively, e.g. "'chat'" inside "enum('voice','chat')".
 *
 * @param {string} tableName
 * @param {string} columnName
 * @param {string} expectedType - Fragment to look for in COLUMN_TYPE.
 * @returns {Promise<boolean>} false when the column does not exist.
 */
async function columnMatchesType(tableName, columnName, expectedType) {
  const schemaName = process.env.MYSQL_DATABASE || 'bigwo_chat';
  const lookupSql = 'SELECT COLUMN_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA=? AND TABLE_NAME=? AND COLUMN_NAME=?';
  const [found] = await pool.query(lookupSql, [schemaName, tableName, columnName]);
  if (found.length === 0) {
    return false;
  }
  const declaredType = found[0].COLUMN_TYPE.toLowerCase();
  const wanted = expectedType.toLowerCase();
  return declaredType.includes(wanted);
}
|
|
|
|
/**
 * Upgrade pre-existing tables to the current schema.
 *
 * `CREATE TABLE IF NOT EXISTS` leaves existing tables untouched, so this widens
 * ENUM columns that lack newer members and adds missing columns. Every step
 * checks the live schema first; on an up-to-date database no ALTER is issued.
 *
 * @returns {Promise<void>}
 */
async function migrateSchema() {
  // ENUM widening: if the probe member is missing, redefine the column with the
  // full current member list.
  if (!(await columnMatchesType('sessions', 'mode', "'chat'"))) {
    await pool.execute("ALTER TABLE `sessions` MODIFY COLUMN `mode` ENUM('voice', 'chat') DEFAULT 'chat'");
  }
  if (!(await columnMatchesType('messages', 'role', "'system'"))) {
    await pool.execute("ALTER TABLE `messages` MODIFY COLUMN `role` ENUM('user', 'assistant', 'tool', 'system') NOT NULL");
  }
  if (!(await columnMatchesType('messages', 'source', "'chat_bot'"))) {
    await pool.execute("ALTER TABLE `messages` MODIFY COLUMN `source` ENUM('voice_asr', 'voice_bot', 'voice_tool', 'chat_user', 'chat_bot') NOT NULL");
  }

  // Missing columns, positioned to match the CREATE TABLE order
  // (source -> tool_name -> meta_json -> created_at). created_at must follow
  // meta_json; placing it AFTER tool_name would wedge it between the two.
  await ensureColumnExists('messages', 'tool_name', '`tool_name` VARCHAR(64) NULL AFTER `source`');
  await ensureColumnExists('messages', 'meta_json', '`meta_json` JSON NULL AFTER `tool_name`');
  await ensureColumnExists('messages', 'created_at', '`created_at` BIGINT NULL AFTER `meta_json`');
  await ensureColumnExists('sessions', 'updated_at', '`updated_at` BIGINT NULL AFTER `created_at`');
}
|
|
|
|
/**
 * Connection settings shared by the bootstrap connection and the pool, read from
 * MYSQL_HOST / MYSQL_PORT / MYSQL_USER / MYSQL_PASSWORD with local defaults.
 *
 * @returns {{host: string, port: number, user: string, password: string}}
 */
function baseConnectionConfig() {
  return {
    host: process.env.MYSQL_HOST || 'localhost',
    port: Number.parseInt(process.env.MYSQL_PORT || '3306', 10),
    user: process.env.MYSQL_USER || 'root',
    password: process.env.MYSQL_PASSWORD || '',
  };
}

/**
 * Initialize the MySQL connection pool and create/migrate the tables.
 *
 * Steps:
 *  1. Connect without selecting a database and `CREATE DATABASE IF NOT EXISTS`
 *     (name from MYSQL_DATABASE, default 'bigwo_chat').
 *  2. Create the module-level pool bound to that database.
 *  3. Create `sessions` and `messages` if missing, then run migrateSchema().
 *
 * The bootstrap connection is always closed, even if database creation fails.
 * If step 3 fails, the new pool is closed and the module is left uninitialized
 * (getPool() throws) before the original error is rethrown.
 *
 * @returns {Promise<object>} the mysql2 pool that was created.
 * @throws Any connection or DDL error reported by mysql2.
 */
async function initialize() {
  const dbName = process.env.MYSQL_DATABASE || 'bigwo_chat';

  // Connect without a database first so the database itself can be created.
  const tempConn = await mysql.createConnection(baseConnectionConfig());
  try {
    await tempConn.execute(`CREATE DATABASE IF NOT EXISTS \`${dbName}\` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci`);
  } finally {
    await tempConn.end();
  }

  pool = mysql.createPool({
    ...baseConnectionConfig(),
    database: dbName,
    waitForConnections: true,
    connectionLimit: 10,
    charset: 'utf8mb4',
  });

  try {
    await pool.execute(`
      CREATE TABLE IF NOT EXISTS sessions (
        id VARCHAR(128) PRIMARY KEY,
        user_id VARCHAR(128),
        mode ENUM('voice', 'chat') DEFAULT 'chat',
        created_at BIGINT,
        updated_at BIGINT,
        INDEX idx_user (user_id),
        INDEX idx_updated (updated_at)
      ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
    `);

    await pool.execute(`
      CREATE TABLE IF NOT EXISTS messages (
        id INT AUTO_INCREMENT PRIMARY KEY,
        session_id VARCHAR(128) NOT NULL,
        role ENUM('user', 'assistant', 'tool', 'system') NOT NULL,
        content TEXT NOT NULL,
        source ENUM('voice_asr', 'voice_bot', 'voice_tool', 'chat_user', 'chat_bot') NOT NULL,
        tool_name VARCHAR(64),
        meta_json JSON,
        created_at BIGINT,
        INDEX idx_session (session_id),
        INDEX idx_session_time (session_id, created_at)
      ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
    `);

    // migrateSchema() reads the module-level pool, so it must run after assignment.
    await migrateSchema();
  } catch (err) {
    const failedPool = pool;
    pool = null;
    // Best-effort close; the setup error is the one worth surfacing.
    await failedPool.end().catch(() => {});
    throw err;
  }

  console.log(`[DB] MySQL connected: ${dbName}, tables ready`);
  return pool;
}
|
|
|
|
/**
 * Return the shared connection pool.
 *
 * @returns {object} the pool assigned by initialize().
 * @throws {Error} when the pool has not been initialized.
 */
function getPool() {
  if (pool) {
    return pool;
  }
  throw new Error('[DB] Not initialized. Call initialize() first.');
}
|
|
|
|
// ==================== Sessions ====================
|
|
|
|
/**
 * Create a session, or refresh an existing one with the same id.
 *
 * A new row gets created_at = updated_at = now. On a duplicate id only `mode`
 * and `updated_at` are overwritten; the stored user_id and created_at are kept.
 *
 * @param {string} sessionId
 * @param {*} [userId] - Falsy values are stored as NULL (the raw value is echoed back).
 * @param {string} [mode='chat']
 * @returns {Promise<{id: string, userId: *, mode: string, created_at: number}>}
 *   Echo of the arguments plus this call's timestamp. For an existing session,
 *   created_at here is the call time, not the stored creation time.
 */
async function createSession(sessionId, userId, mode = 'chat') {
  const timestamp = Date.now();
  const upsertSql = 'INSERT INTO sessions (id, user_id, mode, created_at, updated_at) VALUES (?, ?, ?, ?, ?) ON DUPLICATE KEY UPDATE mode=VALUES(mode), updated_at=VALUES(updated_at)';
  const storedUserId = userId || null;
  await pool.execute(upsertSql, [sessionId, storedUserId, mode, timestamp, timestamp]);
  return {
    id: sessionId,
    userId,
    mode,
    created_at: timestamp,
  };
}
|
|
|
|
/**
 * Set a session's mode and bump its updated_at to now.
 *
 * @param {string} sessionId
 * @param {string} mode
 * @returns {Promise<void>} nothing is reported about how many rows matched.
 */
async function updateSessionMode(sessionId, mode) {
  const updateSql = 'UPDATE sessions SET mode=?, updated_at=? WHERE id=?';
  const bindings = [mode, Date.now(), sessionId];
  await pool.execute(updateSql, bindings);
}
|
|
|
|
/**
 * Fetch one session row by id.
 *
 * @param {string} sessionId
 * @returns {Promise<object|null>} the full `sessions` row, or null if none matches.
 */
async function getSession(sessionId) {
  const [found] = await pool.execute('SELECT * FROM sessions WHERE id=?', [sessionId]);
  const [firstRow] = found;
  return firstRow || null;
}
|
|
|
|
// ==================== Messages ====================
|
|
|
|
/**
 * Append a message to a session and bump the session's updated_at.
 *
 * Falsy or whitespace-only content is ignored: nothing is written and null is
 * returned. `meta` is stored JSON-encoded; null/undefined are stored as NULL.
 *
 * @param {string} sessionId
 * @param {string} role - One of the messages.role ENUM members.
 * @param {string} content
 * @param {string} source - One of the messages.source ENUM members.
 * @param {string|null} [toolName=null]
 * @param {*} [meta=null]
 * @returns {Promise<object|null>} the row as written (with its auto-increment id), or null.
 */
async function addMessage(sessionId, role, content, source, toolName = null, meta = null) {
  const isBlank = !content || content.trim() === '';
  if (isBlank) {
    return null;
  }

  const timestamp = Date.now();
  let metaJson = null;
  if (meta !== null && meta !== undefined) {
    metaJson = JSON.stringify(meta);
  }

  const insertSql = 'INSERT INTO messages (session_id, role, content, source, tool_name, meta_json, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)';
  const [insertResult] = await pool.execute(insertSql, [sessionId, role, content, source, toolName, metaJson, timestamp]);

  // Keep the parent session's updated_at in step with its newest message.
  await pool.execute('UPDATE sessions SET updated_at=? WHERE id=?', [timestamp, sessionId]);

  return {
    id: insertResult.insertId,
    session_id: sessionId,
    role,
    content,
    source,
    tool_name: toolName,
    meta_json: metaJson,
    created_at: timestamp,
  };
}
|
|
|
|
/**
 * Return the oldest messages of a session in chronological order.
 *
 * @param {string} sessionId
 * @param {number|string} [limit=20] - Clamped to 1..100; values parsing to 0/NaN fall back to 20.
 * @returns {Promise<object[]>} rows with role, content, source, tool_name, meta_json, created_at.
 */
async function getMessages(sessionId, limit = 20) {
  const safeLimit = Math.max(1, Math.min(Number.parseInt(limit, 10) || 20, 100));
  // created_at has millisecond resolution, so messages stored in the same
  // millisecond tie; the auto-increment id keeps their insertion order stable.
  // NOTE(review): `query` (client-side interpolation) rather than `execute` is
  // presumably deliberate for the LIMIT placeholder — confirm before switching.
  const [rows] = await pool.query(
    'SELECT role, content, source, tool_name, meta_json, created_at FROM messages WHERE session_id=? ORDER BY created_at ASC, id ASC LIMIT ?',
    [sessionId, safeLimit],
  );
  return rows;
}
|
|
|
|
/**
 * Return the most recent messages of a session, oldest first.
 *
 * The newest `limit` rows are selected in descending order and then reversed.
 *
 * @param {string} sessionId
 * @param {number|string} [limit=20] - Clamped to 1..100; values parsing to 0/NaN fall back to 20.
 * @returns {Promise<object[]>} rows with role, content, source, tool_name, meta_json, created_at.
 */
async function getRecentMessages(sessionId, limit = 20) {
  const safeLimit = Math.max(1, Math.min(Number.parseInt(limit, 10) || 20, 100));
  // `id DESC` breaks created_at ties (millisecond resolution) so messages written
  // in the same millisecond keep their insertion order after the reverse below.
  const [rows] = await pool.query(
    `SELECT role, content, source, tool_name, meta_json, created_at FROM messages
     WHERE session_id=? ORDER BY created_at DESC, id DESC LIMIT ?`,
    [sessionId, safeLimit],
  );
  return rows.reverse();
}
|
|
|
|
/**
 * Get a session's recent history as `{ role, content }` objects for an LLM prompt.
 *
 * Only 'user' and 'assistant' messages are kept; 'tool' and 'system' rows are
 * dropped (not merged into assistant messages). Filtering happens after the
 * `limit` most recent rows are fetched, so the result may hold fewer than
 * `limit` entries.
 *
 * @param {string} sessionId
 * @param {number|string} [limit=20] - Passed through to getRecentMessages().
 * @returns {Promise<Array<{role: string, content: string}>>} oldest first.
 */
async function getHistoryForLLM(sessionId, limit = 20) {
  const messages = await getRecentMessages(sessionId, limit);
  return messages
    .filter((m) => m.role === 'user' || m.role === 'assistant')
    .map(({ role, content }) => ({ role, content }));
}
|
|
|
|
/**
 * List sessions, most recently updated first, each with a preview of its latest
 * message and its total message count.
 *
 * @param {string} [userId] - When truthy, only this user's sessions are listed; otherwise all sessions.
 * @param {number|string} [limit=50] - Clamped to 1..200; values parsing to 0/NaN fall back to 50.
 * @returns {Promise<object[]>} rows: id, user_id, mode, created_at, updated_at, last_message, message_count.
 */
async function getSessionList(userId, limit = 50) {
  const safeLimit = Math.max(1, Math.min(Number.parseInt(limit, 10) || 50, 200));

  // Only the WHERE clause (and its bound value) depends on userId.
  const whereClause = userId ? 'WHERE s.user_id = ?' : '';
  const params = userId ? [userId, safeLimit] : [safeLimit];

  // `id DESC` breaks ties between messages stored in the same millisecond, so the
  // preview is the latest-inserted message rather than an arbitrary one.
  const query = `SELECT s.id, s.user_id, s.mode, s.created_at, s.updated_at,
      (SELECT content FROM messages WHERE session_id = s.id ORDER BY created_at DESC, id DESC LIMIT 1) AS last_message,
      (SELECT COUNT(*) FROM messages WHERE session_id = s.id) AS message_count
    FROM sessions s ${whereClause} ORDER BY s.updated_at DESC LIMIT ?`;

  const [rows] = await pool.query(query, params);
  return rows;
}
|
|
|
|
/**
 * Delete a session and all of its messages.
 *
 * Both DELETEs run in a single transaction on a dedicated pooled connection (the
 * tables are InnoDB), so a failure part-way leaves neither table half-cleaned.
 *
 * @param {string} sessionId
 * @returns {Promise<void>}
 * @throws The underlying mysql2 error, after the transaction is rolled back.
 */
async function deleteSession(sessionId) {
  const conn = await pool.getConnection();
  try {
    await conn.beginTransaction();
    // Messages are removed before their session row.
    await conn.execute('DELETE FROM messages WHERE session_id = ?', [sessionId]);
    await conn.execute('DELETE FROM sessions WHERE id = ?', [sessionId]);
    await conn.commit();
  } catch (err) {
    // Best-effort rollback; the original error is the one to report.
    await conn.rollback().catch(() => {});
    throw err;
  } finally {
    // Return the connection to the pool whether or not the transaction succeeded.
    conn.release();
  }
}
|
|
|
|
module.exports = {
|
|
initialize,
|
|
getPool,
|
|
createSession,
|
|
updateSessionMode,
|
|
getSession,
|
|
addMessage,
|
|
getMessages,
|
|
getRecentMessages,
|
|
getHistoryForLLM,
|
|
getSessionList,
|
|
deleteSession,
|
|
};
|