feat: 添加realtime_dialog和realtime_dialog_external_rag_test项目,更新test2项目

This commit is contained in:
User
2026-03-13 13:06:46 +08:00
parent 9dab61345c
commit 5521b673f5
215 changed files with 7626 additions and 1876 deletions

View File

@@ -1 +1 @@
34689:1773278288416_21076_pox4drkjp4mr:1773290598977
34689:1773367134016_22200_768qlmindc7g:1773378240856

View File

@@ -1,7 +1,7 @@
{
"token": "53fec3387174916c212b85c8cc00d71e745ccc3bc25851295aa18cc63e129e1d",
"expires_at": 1773291780,
"token": "47c96fae6ababba95306064d77f8cb9d0ab937523d5c08825e3061a7617b902d",
"expires_at": 1773378566,
"card_type": 7,
"card_expires_at": 1773883128,
"created_at": 1773289980
"created_at": 1773376766
}

View File

@@ -119,25 +119,32 @@ async function main() {
const serverFiles = [
{ local: "server\\app.js", remote: `${PROJECT}/server/app.js` },
{ local: "server\\package.json", remote: `${PROJECT}/server/package.json` },
{ local: "server\\package-lock.json", remote: `${PROJECT}/server/package-lock.json` },
{ local: "server\\routes\\chat.js", remote: `${PROJECT}/server/routes/chat.js` },
{ local: "server\\routes\\session.js", remote: `${PROJECT}/server/routes/session.js` },
{ local: "server\\routes\\voice.js", remote: `${PROJECT}/server/routes/voice.js` },
{ local: "server\\services\\arkChatService.js", remote: `${PROJECT}/server/services/arkChatService.js` },
{ local: "server\\services\\cozeChatService.js", remote: `${PROJECT}/server/services/cozeChatService.js` },
{ local: "server\\services\\nativeVoiceGateway.js", remote: `${PROJECT}/server/services/nativeVoiceGateway.js` },
{ local: "server\\services\\realtimeDialogProtocol.js", remote: `${PROJECT}/server/services/realtimeDialogProtocol.js` },
{ local: "server\\services\\realtimeDialogRouting.js", remote: `${PROJECT}/server/services/realtimeDialogRouting.js` },
{ local: "server\\services\\toolExecutor.js", remote: `${PROJECT}/server/services/toolExecutor.js` },
{ local: "server\\services\\volcengine.js", remote: `${PROJECT}/server/services/volcengine.js` },
{ local: "server\\config\\tools.js", remote: `${PROJECT}/server/config/tools.js` },
{ local: "server\\config\\voiceChatConfig.js", remote: `${PROJECT}/server/config/voiceChatConfig.js` },
{ local: "server\\db\\index.js", remote: `${PROJECT}/server/db/index.js` },
{ local: "server\\lib\\token.js", remote: `${PROJECT}/server/lib/token.js` },
];
const clientFiles = [
{ local: "client\\dist\\index.html", remote: `${PROJECT}/client/dist/index.html` },
{ local: "client\\dist\\assets\\index-DR-ymgvy.css", remote: `${PROJECT}/client/dist/assets/index-DR-ymgvy.css` },
{ local: "client\\dist\\assets\\index-DV4vMa2s.js", remote: `${PROJECT}/client/dist/assets/index-DV4vMa2s.js` },
{ local: "client\\dist\\assets\\index.esm.min-C5F81t8Q.js", remote: `${PROJECT}/client/dist/assets/index.esm.min-C5F81t8Q.js` },
];
const localAssetsDir = join(LOCAL_BASE, "client", "dist", "assets");
const assetNames = readdirSync(localAssetsDir).filter((name) => statSync(join(localAssetsDir, name)).isFile());
assetNames.forEach((name) => {
clientFiles.push({
local: `client\\dist\\assets\\${name}`,
remote: `${PROJECT}/client/dist/assets/${name}`,
});
});
console.log("=== 1. 检查服务器状态 ===");
const pm2Check = await sshExec("pm2 list 2>&1 | head -20");
@@ -150,17 +157,19 @@ async function main() {
const backupFiles = [
`${PROJECT}/server/app.js`,
`${PROJECT}/server/package.json`,
`${PROJECT}/server/package-lock.json`,
`${PROJECT}/server/routes/chat.js`,
`${PROJECT}/server/routes/session.js`,
`${PROJECT}/server/routes/voice.js`,
`${PROJECT}/server/services/arkChatService.js`,
`${PROJECT}/server/services/cozeChatService.js`,
`${PROJECT}/server/services/nativeVoiceGateway.js`,
`${PROJECT}/server/services/realtimeDialogProtocol.js`,
`${PROJECT}/server/services/realtimeDialogRouting.js`,
`${PROJECT}/server/services/toolExecutor.js`,
`${PROJECT}/server/services/volcengine.js`,
`${PROJECT}/server/config/tools.js`,
`${PROJECT}/server/config/voiceChatConfig.js`,
`${PROJECT}/server/db/index.js`,
`${PROJECT}/server/lib/token.js`,
`${PROJECT}/client/dist/index.html`,
];
@@ -170,6 +179,10 @@ async function main() {
}
console.log("备份完成");
console.log("\n=== 2b. 删除已废弃的 RTC 文件 ===");
const removeRtc = await sshExec(`rm -f ${PROJECT}/server/services/volcengine.js ${PROJECT}/server/config/voiceChatConfig.js ${PROJECT}/server/lib/token.js 2>&1`);
console.log("已清理远程 RTC 残留文件");
console.log("\n=== 3. 同步服务端代码 ===");
for (const { local, remote } of serverFiles) {
const localPath = join(LOCAL_BASE, local);
@@ -182,6 +195,7 @@ async function main() {
}
console.log("\n=== 4. 同步前端构建产物 ===");
await sshExec(`mkdir -p ${PROJECT}/client/dist/assets && find ${PROJECT}/client/dist/assets -maxdepth 1 -type f -delete`);
for (const { local, remote } of clientFiles) {
const localPath = join(LOCAL_BASE, local);
try {
@@ -193,6 +207,11 @@ async function main() {
}
console.log("\n=== 5. 重启 PM2 服务 ===");
console.log("\n=== 5a. 安装服务端依赖 ===");
const install = await sshExec(`cd ${PROJECT}/server && npm install --production`, 180000);
console.log(install.stdout || install.stderr);
console.log("\n=== 5b. 重启 PM2 服务 ===");
const restart = await sshExec("pm2 restart bigwo-server 2>&1");
console.log(restart.stdout);

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

6
realtime_dialog/java/.idea/.gitignore generated vendored Normal file
View File

@@ -0,0 +1,6 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/
/target/

View File

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="com.codeverse.userSettings.CodeverseWorkspaceAppSettingsState">
<option name="progress" value="1.0" />
</component>
</project>

13
realtime_dialog/java/.idea/compiler.xml generated Normal file
View File

@@ -0,0 +1,13 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="CompilerConfiguration">
<annotationProcessing>
<profile name="Maven default annotation processors profile" enabled="true">
<sourceOutputDir name="target/generated-sources/annotations" />
<sourceTestOutputDir name="target/generated-test-sources/test-annotations" />
<outputRelativeToContentRoot value="true" />
<module name="realtimedialog" />
</profile>
</annotationProcessing>
</component>
</project>

View File

@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Encoding">
<file url="file://$PROJECT_DIR$/src/main/java" charset="UTF-8" />
<file url="file://$PROJECT_DIR$/src/main/resources" charset="UTF-8" />
</component>
</project>

View File

@@ -0,0 +1,30 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="RemoteRepositoriesConfiguration">
<remote-repository>
<option name="id" value="bytedance-snapshots" />
<option name="name" value="bytedance-snapshots" />
<option name="url" value="https://maven.byted.org/repository/public" />
</remote-repository>
<remote-repository>
<option name="id" value="bytedance-releases" />
<option name="name" value="bytedance-releases" />
<option name="url" value="https://maven.byted.org/repository/public" />
</remote-repository>
<remote-repository>
<option name="id" value="central" />
<option name="name" value="Central Repository" />
<option name="url" value="https://maven.byted.org/repository/public" />
</remote-repository>
<remote-repository>
<option name="id" value="central" />
<option name="name" value="Maven Central repository" />
<option name="url" value="https://repo1.maven.org/maven2" />
</remote-repository>
<remote-repository>
<option name="id" value="jboss.community" />
<option name="name" value="JBoss Community repository" />
<option name="url" value="https://repository.jboss.org/nexus/content/repositories/public/" />
</remote-repository>
</component>
</project>

12
realtime_dialog/java/.idea/misc.xml generated Normal file
View File

@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ExternalStorageConfigurationManager" enabled="true" />
<component name="MavenProjectsManager">
<option name="originalFiles">
<list>
<option value="$PROJECT_DIR$/pom.xml" />
</list>
</option>
</component>
<component name="ProjectRootManager" version="2" languageLevel="JDK_1_8" default="true" project-jdk-name="corretto-1.8" project-jdk-type="JavaSDK" />
</project>

6
realtime_dialog/java/.idea/vcs.xml generated Normal file
View File

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$/.." vcs="Git" />
</component>
</project>

View File

@@ -0,0 +1,63 @@
# RealtimeDialog Java客户端
## 项目简介
Java版本的RealtimeDialog客户端支持实时语音对话功能。
## 环境要求
- Java 1.8 或更高版本
- Maven 3.6 或更高版本
## 快速开始
### 1. 编译项目
```bash
cd java
mvn clean compile
```
### 2. 运行应用
#### 麦克风模式(默认)
```bash
mvn exec:java
```
#### 音频文件模式
```bash
mvn exec:java -Dexec.args="--audio=whoareyou.wav"
```
#### 文本模式
```bash
mvn exec:java -Dexec.args="--mod=text"
```
#### 指定音频格式
```bash
mvn exec:java -Dexec.args="--format=pcm_s16le"
```
### 3. 打包可执行JAR
```bash
mvn clean package
java -jar target/realtimedialog-1.0.0.jar --audio=whoareyou.wav
```
## 配置说明
在使用前,需要在`Config.java`中配置以下参数:
- `X-Api-App-ID`: 你的应用ID
- `X-Api-Access-Key`: 你的访问密钥
## 功能特性
- 支持麦克风实时语音输入
- 支持音频文件输入
- 支持文本输入模式
- 支持音频输出播放
- 支持外部RAG功能
- 支持多种音频格式pcm, pcm_s16le
## 命令行参数
- `--format`: 音频格式,默认为"pcm"
- `--audio`: 音频文件路径,如果不设置则使用麦克风输入
- `--mod`: 输入模式audio默认或text

View File

@@ -0,0 +1,56 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.volcengine</groupId>
<artifactId>realtimedialog</artifactId>
<name>RealtimeDialog Java Client</name>
<version>1.0.0</version>
<description>Java client for Volcengine RealtimeDialog service</description>
<build>
<plugins>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.11.0</version>
<configuration>
<source>1.8</source>
<target>1.8</target>
<encoding>UTF-8</encoding>
</configuration>
</plugin>
<plugin>
<artifactId>maven-shade-plugin</artifactId>
<version>3.4.1</version>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<transformers>
<transformer>
<mainClass>com.volcengine.realtimedialog.Main</mainClass>
</transformer>
</transformers>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>exec-maven-plugin</artifactId>
<version>3.1.0</version>
<configuration>
<mainClass>com.volcengine.realtimedialog.Main</mainClass>
</configuration>
</plugin>
</plugins>
</build>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.target>1.8</maven.compiler.target>
<java-websocket.version>1.5.3</java-websocket.version>
<jackson.version>2.15.2</jackson.version>
<maven.compiler.source>1.8</maven.compiler.source>
</properties>
</project>

View File

@@ -0,0 +1,128 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.volcengine</groupId>
<artifactId>realtimedialog</artifactId>
<version>1.0.0</version>
<packaging>jar</packaging>
<name>RealtimeDialog Java Client</name>
<description>Java client for Volcengine RealtimeDialog service</description>
<properties>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<jackson.version>2.15.2</jackson.version>
<java-websocket.version>1.5.3</java-websocket.version>
</properties>
<dependencies>
<!-- WebSocket client -->
<dependency>
<groupId>org.java-websocket</groupId>
<artifactId>Java-WebSocket</artifactId>
<version>${java-websocket.version}</version>
</dependency>
<!-- JSON processing -->
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<version>${jackson.version}</version>
</dependency>
<!-- Logging -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-simple</artifactId>
<version>1.7.36</version>
</dependency>
<!-- Apache Commons CLI for command line parsing -->
<dependency>
<groupId>commons-cli</groupId>
<artifactId>commons-cli</artifactId>
<version>1.5.0</version>
</dependency>
<!-- Audio processing -->
<dependency>
<groupId>com.googlecode.soundlibs</groupId>
<artifactId>mp3spi</artifactId>
<version>1.9.5.4</version>
</dependency>
<!-- UUID generation -->
<dependency>
<groupId>com.fasterxml.uuid</groupId>
<artifactId>java-uuid-generator</artifactId>
<version>4.2.0</version>
</dependency>
<!-- Base64 encoding -->
<dependency>
<groupId>commons-codec</groupId>
<artifactId>commons-codec</artifactId>
<version>1.15</version>
</dependency>
</dependencies>
<build>
<plugins>
<!-- Compiler plugin -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.11.0</version>
<configuration>
<source>1.8</source>
<target>1.8</target>
<encoding>UTF-8</encoding>
</configuration>
</plugin>
<!-- Shade plugin for creating fat jar -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>3.4.1</version>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
<mainClass>com.volcengine.realtimedialog.Main</mainClass>
</transformer>
</transformers>
</configuration>
</execution>
</executions>
</plugin>
<!-- Exec plugin for running the application -->
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>exec-maven-plugin</artifactId>
<version>3.1.0</version>
<configuration>
<mainClass>com.volcengine.realtimedialog.Main</mainClass>
</configuration>
</plugin>
</plugins>
</build>
</project>

View File

@@ -0,0 +1,114 @@
package com.volcengine.realtimedialog;
import javax.sound.sampled.*;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
public class AudioCapture {
private static final int BUFFER_SIZE = 4096;
private TargetDataLine targetLine;
private volatile boolean isCapturing = false;
private Thread captureThread;
private final BlockingQueue<byte[]> audioQueue;
public AudioCapture() {
this.audioQueue = new ArrayBlockingQueue<>(100);
}
public void startCapture() throws LineUnavailableException {
AudioFormat format = new AudioFormat(
Config.INPUT_SAMPLE_RATE,
16,
Config.CHANNELS,
true,
false // little endian
);
DataLine.Info info = new DataLine.Info(TargetDataLine.class, format);
if (!AudioSystem.isLineSupported(info)) {
throw new LineUnavailableException("音频输入设备不支持指定格式");
}
targetLine = (TargetDataLine) AudioSystem.getLine(info);
targetLine.open(format);
targetLine.start();
isCapturing = true;
captureThread = new Thread(this::captureLoop);
captureThread.setName("AudioCapture");
captureThread.start();
}
private void captureLoop() {
byte[] buffer = new byte[Config.AUDIO_CHUNK_SIZE];
while (isCapturing) {
int bytesRead = targetLine.read(buffer, 0, buffer.length);
if (bytesRead > 0) {
byte[] audioData = new byte[bytesRead];
System.arraycopy(buffer, 0, audioData, 0, bytesRead);
try {
audioQueue.put(audioData);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
break;
}
}
}
}
public byte[] readAudioData() throws InterruptedException {
return audioQueue.poll();
}
public void stopCapture() {
isCapturing = false;
if (captureThread != null) {
captureThread.interrupt();
}
if (targetLine != null) {
targetLine.stop();
targetLine.close();
}
}
public boolean isCapturing() {
return isCapturing;
}
public static byte[] readWavFile(String filePath) throws IOException {
File file = new File(filePath);
if (!file.exists()) {
throw new FileNotFoundException("音频文件不存在: " + filePath);
}
try (FileInputStream fis = new FileInputStream(file)) {
byte[] fileData = new byte[(int) file.length()];
fis.read(fileData);
// 跳过WAV文件头44字节
if (filePath.toLowerCase().endsWith(".wav") && fileData.length > Config.WAV_HEADER_SIZE) {
byte[] audioData = new byte[fileData.length - Config.WAV_HEADER_SIZE];
System.arraycopy(fileData, Config.WAV_HEADER_SIZE, audioData, 0, audioData.length);
return audioData;
}
return fileData;
}
}
public static short[] bytesToInt16Samples(byte[] data) {
short[] samples = new short[data.length / 2];
ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer().get(samples);
return samples;
}
public static byte[] int16SamplesToBytes(short[] samples) {
byte[] bytes = new byte[samples.length * 2];
ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer().put(samples);
return bytes;
}
}

View File

@@ -0,0 +1,398 @@
package com.volcengine.realtimedialog;
import java.io.IOException;
import java.net.URI;
import java.util.HashMap;
import java.util.Map;
import java.util.Scanner;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
public class CallManager {
private final String sessionId;
private NetClient netClient;
private AudioCapture audioCapture;
private Thread audioSendThread;
private Thread textInputThread;
private final AtomicBoolean isRunning;
private final AtomicBoolean isAudioMode;
private static final BlockingQueue<Object> queryChan = new LinkedBlockingQueue<>();
private static volatile CallManager currentInstance;
public CallManager() {
this.sessionId = Protocol.generateSessionId();
this.isRunning = new AtomicBoolean(false);
this.isAudioMode = new AtomicBoolean(true);
}
public void start() throws Exception {
System.out.println("启动实时通话管理器会话ID: " + sessionId);
// 设置当前实例
currentInstance = this;
// 建立WebSocket连接
connectWebSocket();
isRunning.set(true);
// 启动音频模式或文本模式
if (Config.mod.equals("text")) {
isAudioMode.set(false);
startTextMode();
} else {
isAudioMode.set(true);
startAudioMode();
}
// 等待运行结束
waitForCompletion();
}
private void connectWebSocket() throws Exception {
URI uri = new URI(Config.WS_URL);
Map<String, String> headers = new HashMap<>();
headers.put("X-Api-Resource-Id", Config.API_RESOURCE_ID);
headers.put("X-Api-Access-Key", Config.API_ACCESS_KEY);
headers.put("X-Api-App-Key", Config.API_APP_KEY);
headers.put("X-Api-App-ID", Config.API_APP_ID);
headers.put("X-Api-Connect-Id", sessionId);
netClient = new NetClient(uri, headers);
netClient.connectBlocking(30, TimeUnit.SECONDS);
if (!netClient.isConnected()) {
throw new IOException("WebSocket连接失败");
}
System.out.println("WebSocket连接成功");
// 发送连接开始消息事件1
startConnection();
// 发送会话开始消息事件100
startSession();
}
private void startConnection() throws Exception {
System.out.println("发送连接开始消息...");
// 使用正确的协议格式发送事件1
netClient.sendProtocolMessage(sessionId, "{}", 1);
System.out.println("连接开始消息发送完成");
}
private void startSession() throws Exception {
System.out.println("发送会话开始消息...");
RequestPayloads.StartSessionPayload payload = new RequestPayloads.StartSessionPayload();
// 根据模式设置参数
if (Config.mod.equals("text")) {
payload.dialog.extra = createExtraMap("text");
} else if (!Config.audioFilePath.isEmpty()) {
payload.dialog.extra = createExtraMap("audio_file");
} else {
payload.dialog.extra = createExtraMap("audio");
}
// 发送会话开始消息
String jsonPayload = new com.fasterxml.jackson.databind.ObjectMapper().writeValueAsString(payload);
netClient.sendProtocolMessage(sessionId, jsonPayload, 100); // 事件100的载荷
System.out.println("会话开始消息发送完成,等待服务器响应...");
// 等待会话启动响应事件150对齐Go实现
long startTime = System.currentTimeMillis();
long timeout = 30000; // 30秒超时
boolean sessionStarted = false;
while (System.currentTimeMillis() - startTime < timeout) {
Protocol.Message message = netClient.pollIncomingMessage(1, TimeUnit.SECONDS);
if (message != null && message.type == Protocol.MsgType.FULL_SERVER && message.event == 150) {
// 解析响应payload获取dialog_id
try {
String responseJson = new String(message.payload, "UTF-8");
com.fasterxml.jackson.databind.ObjectMapper mapper = new com.fasterxml.jackson.databind.ObjectMapper();
java.util.Map<String, Object> response = mapper.readValue(responseJson, java.util.Map.class);
String dialogId = (String) response.get("dialog_id");
if (dialogId != null && !dialogId.isEmpty()) {
System.out.println("会话启动成功dialog_id: " + dialogId);
sessionStarted = true;
break;
}
} catch (Exception e) {
System.err.println("解析会话启动响应失败: " + e.getMessage());
}
}
}
if (!sessionStarted) {
throw new IOException("会话启动超时或失败,未收到服务器的会话启动确认");
}
System.out.println("会话开始完成\n" + jsonPayload);
}
private Map<String, Object> createExtraMap(String inputMod) {
Map<String, Object> extra = new HashMap<>();
extra.put("strict_audit", false);
extra.put("audit_response", "抱歉这个问题我无法回答,你可以换个其他话题,我会尽力为你提供帮助。");
extra.put("input_mod", inputMod);
extra.put("model", "O");
return extra;
}
private void startAudioMode() throws Exception {
System.out.println("启动音频模式");
if (Config.audioFilePath.isEmpty()) {
// 麦克风模式
sendGreetingMessage();
startMicrophoneCapture();
startMessageReceiver();
} else {
// 音频文件模式 - 不启动麦克风,只启动消息接收
startFilePlayback();
startMessageReceiver();
}
}
private void startTextMode() throws Exception {
System.out.println("启动文本模式");
// 发送问候语对齐Golang版本
sendGreetingMessage();
// 启动文本输入线程
startTextInput();
// 启动消息接收线程
startMessageReceiver();
}
private void startMicrophoneCapture() throws Exception {
audioCapture = new AudioCapture();
audioCapture.startCapture();
// 启动音频发送线程
audioSendThread = new Thread(this::microphoneSendLoop);
audioSendThread.setName("MicrophoneAudioSend");
audioSendThread.start();
System.out.println("麦克风采集已启动");
}
private void startFilePlayback() throws Exception {
System.out.println("开始发送音频文件: " + Config.audioFilePath);
// 读取音频文件
byte[] audioData = AudioCapture.readWavFile(Config.audioFilePath);
// 启动文件发送线程
audioSendThread = new Thread(() -> fileSendLoop(audioData));
audioSendThread.setName("FileAudioSend");
audioSendThread.start();
}
private void microphoneSendLoop() {
try {
while (isRunning.get() && audioCapture.isCapturing()) {
byte[] audioData = audioCapture.readAudioData();
if (audioData != null) {
netClient.sendAudioData(sessionId, audioData);
}
// 模拟实时发送间隔
Thread.sleep(Config.AUDIO_SEND_INTERVAL);
}
} catch (Exception e) {
System.err.println("麦克风发送线程错误: " + e.getMessage());
}
}
private void fileSendLoop(byte[] audioData) {
try {
int chunkSize = Config.AUDIO_CHUNK_SIZE; // 640字节与Go实现保持一致
int totalSize = audioData.length;
int position = 0;
int chunkCount = 0;
System.out.println("开始发送音频文件,总大小: " + totalSize + " 字节, 块大小: " + chunkSize + " 字节");
while (isRunning.get() && position < totalSize) {
int remaining = totalSize - position;
int currentChunkSize = Math.min(chunkSize, remaining);
byte[] chunk = new byte[currentChunkSize];
System.arraycopy(audioData, position, chunk, 0, currentChunkSize);
System.out.println("发送音频块 #" + (++chunkCount) + ": 位置=" + position + ", 大小=" + currentChunkSize + " 字节");
netClient.sendAudioData(sessionId, chunk);
position += currentChunkSize;
// 模拟实时发送间隔 - 每20ms发送一块与Go实现保持一致
Thread.sleep(Config.AUDIO_SEND_INTERVAL);
}
System.out.println("音频文件发送完成,共发送 " + chunkCount + "");
// 发送音频结束标记 - 发送一段静音数据提示服务器音频输入结束
System.out.println("发送音频结束标记...");
byte[] silenceChunk = new byte[chunkSize]; // 静音数据
netClient.sendAudioData(sessionId, silenceChunk);
System.out.println("音频文件发送完成,等待服务器响应...");
// 文件发送完成后等待服务器通过事件359通知退出
} catch (Exception e) {
System.err.println("文件发送线程错误: " + e.getMessage());
e.printStackTrace();
}
}
private void sendGreetingMessage() throws Exception {
System.out.println("发送问候语...");
// 创建SayHello载荷对齐Golang版本使用事件300
RequestPayloads.SayHelloPayload payload = new RequestPayloads.SayHelloPayload("你好,我是豆包,有什么可以帮助你的吗?");
String jsonPayload = new com.fasterxml.jackson.databind.ObjectMapper().writeValueAsString(payload);
netClient.sendProtocolMessage(sessionId, jsonPayload, 300); // 事件300 - SayHello对齐Golang版本
System.out.println("问候语发送完成");
}
private void startTextInput() {
textInputThread = new Thread(this::textInputLoop);
textInputThread.setName("TextInput");
textInputThread.start();
}
private void textInputLoop() {
Scanner scanner = new Scanner(System.in);
System.out.println("请输入文本 (输入 'quit' 退出):");
try {
while (isRunning.get()) {
String text = scanner.nextLine();
if (text.equalsIgnoreCase("quit")) {
stop();
break;
}
if (!text.trim().isEmpty()) {
// 使用事件501发送文本查询对齐Golang版本
netClient.sendChatTextQuery(sessionId, text);
}
}
} catch (Exception e) {
System.err.println("文本输入线程错误: " + e.getMessage());
}
}
private void startMessageReceiver() {
Thread receiverThread = new Thread(this::messageReceiveLoop);
receiverThread.setName("MessageReceiver");
receiverThread.start();
}
private void messageReceiveLoop() {
try {
while (isRunning.get()) {
Protocol.Message message = netClient.pollIncomingMessage(1, TimeUnit.SECONDS);
if (message != null) {
// 消息已在NetClient中处理这里可以添加额外的逻辑
if (message.type == Protocol.MsgType.ERROR) {
String error = new String(message.payload);
System.err.println("服务器错误: " + error);
}
}
}
} catch (Exception e) {
System.err.println("消息接收线程错误: " + e.getMessage());
}
}
private void waitForCompletion() throws InterruptedException {
while (isRunning.get()) {
Thread.sleep(100);
// 对于音频文件模式,文件发送完成后等待服务器响应
if (isAudioMode.get() && !Config.audioFilePath.isEmpty() && audioSendThread != null && !audioSendThread.isAlive()) {
// 音频文件已发送完成,继续等待服务器响应
System.out.println("音频文件发送完成,等待服务器响应...");
// 不退出,继续等待消息接收线程处理服务器响应
// 服务器会通过事件359通知可以退出
}
}
}
public void stop() {
System.out.println("停止通话管理器");
isRunning.set(false);
try {
// 发送会话结束消息事件102- 参考Go实现
if (netClient != null && netClient.isConnected()) {
System.out.println("发送会话结束消息...");
finishSession();
Thread.sleep(100); // 给服务器处理时间
}
} catch (Exception e) {
System.err.println("发送会话结束消息失败: " + e.getMessage());
}
// 停止音频采集
if (audioCapture != null) {
audioCapture.stopCapture();
}
// 关闭WebSocket连接并打印logid
if (netClient != null) {
String logid = netClient.getLogid();
if (logid != null && !logid.isEmpty()) {
System.out.println("通话结束logid: " + logid);
}
netClient.close();
}
// 等待线程结束
try {
if (audioSendThread != null) {
audioSendThread.join(1000);
}
if (textInputThread != null) {
textInputThread.join(1000);
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
System.out.println("通话管理器已停止");
}
private void finishSession() throws Exception {
if (netClient != null && sessionId != null) {
netClient.sendProtocolMessage(sessionId, "{}", 102); // 事件102 - FinishSession
System.out.println("会话结束消息已发送");
}
}
// 通知用户查询事件
public static void notifyUserQuery() {
queryChan.offer(new Object());
}
// 从处理器停止CallManager
public static void stopFromHandler() {
if (currentInstance != null) {
currentInstance.stop();
}
}
}

View File

@@ -0,0 +1,59 @@
package com.volcengine.realtimedialog;
public class Config {
// WebSocket连接配置
public static final String WS_URL = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue";
public static final String API_RESOURCE_ID = "volc.speech.dialog";
// 用户需要配置的参数
public static String API_APP_ID = "";
public static String API_ACCESS_KEY = "";
public static String API_APP_KEY = "PlgvMymc7f3tQnJ6";
// 音频参数配置
public static final int INPUT_SAMPLE_RATE = 16000;
public static final int OUTPUT_SAMPLE_RATE = 24000;
public static final int CHANNELS = 1;
public static final int INPUT_FRAMES_PER_BUFFER = 160;
public static final int OUTPUT_FRAMES_PER_BUFFER = 512;
public static final int BUFFER_SECONDS = 100;
// 音频格式
public static final String DEFAULT_PCM = "pcm";
public static final String PCM_S16LE = "pcm_s16le";
// TTS配置
public static final String DEFAULT_SPEAKER = "zh_female_vv_jupiter_bigtts";
// 网络配置
public static final int AUDIO_CHUNK_SIZE = 640; // 字节对应20ms音频数据
public static final long AUDIO_SEND_INTERVAL = 20; // 毫秒
// WAV文件配置
public static final int WAV_HEADER_SIZE = 44; // WAV文件头大小
// 命令行参数默认值
public static String audioFilePath = "";
public static String mod = "audio";
public static String pcmFormat = PCM_S16LE;
public static void setAppId(String appId) {
API_APP_ID = appId;
}
public static void setAccessKey(String accessKey) {
API_ACCESS_KEY = accessKey;
}
public static void setAudioFilePath(String path) {
audioFilePath = path;
}
public static void setMod(String mode) {
mod = mode;
}
public static void setPcmFormat(String format) {
pcmFormat = format;
}
}

View File

@@ -0,0 +1,179 @@
package com.volcengine.realtimedialog;
import org.apache.commons.cli.*;
public class Main {
public static void main(String[] args) {
// 解析命令行参数
CommandLine cmd = parseCommandLine(args);
if (cmd == null) {
System.exit(1);
}
// 应用配置
applyConfiguration(cmd);
// 验证必要的配置
if (!validateConfiguration()) {
System.exit(1);
}
// 启动通话管理器
CallManager callManager = new CallManager();
try {
// 添加关闭钩子
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
System.out.println("正在关闭应用...");
callManager.stop();
}));
// 开始通话
callManager.start();
System.out.println("通话结束");
} catch (Exception e) {
System.err.println("运行错误: " + e.getMessage());
e.printStackTrace();
System.exit(1);
}
}
private static CommandLine parseCommandLine(String[] args) {
Options options = new Options();
// 音频文件路径
Option audioOption = Option.builder("a")
.longOpt("audio")
.desc("音频文件路径,如果不设置则使用麦克风输入")
.hasArg()
.argName("FILE")
.build();
options.addOption(audioOption);
// 输入模式
Option modOption = Option.builder("m")
.longOpt("mod")
.desc("输入模式audio默认或text")
.hasArg()
.argName("MODE")
.build();
options.addOption(modOption);
// 音频格式
Option formatOption = Option.builder("f")
.longOpt("format")
.desc("音频格式默认为pcm可选pcm_s16le")
.hasArg()
.argName("FORMAT")
.build();
options.addOption(formatOption);
// 应用ID
Option appIdOption = Option.builder()
.longOpt("app_id")
.desc("应用ID如果不设置则使用Config中的默认值")
.hasArg()
.argName("APP_ID")
.build();
options.addOption(appIdOption);
// 访问密钥
Option accessKeyOption = Option.builder()
.longOpt("access_key")
.desc("访问密钥如果不设置则使用Config中的默认值")
.hasArg()
.argName("ACCESS_KEY")
.build();
options.addOption(accessKeyOption);
// 帮助
Option helpOption = Option.builder("h")
.longOpt("help")
.desc("显示帮助信息")
.build();
options.addOption(helpOption);
CommandLineParser parser = new DefaultParser();
HelpFormatter formatter = new HelpFormatter();
try {
CommandLine cmd = parser.parse(options, args);
if (cmd.hasOption("help")) {
formatter.printHelp("java -jar realtimelog-1.0.0.jar", options);
return null;
}
return cmd;
} catch (ParseException e) {
System.err.println("参数解析错误: " + e.getMessage());
formatter.printHelp("java -jar realtimelog-1.0.0.jar", options);
return null;
}
}
private static void applyConfiguration(CommandLine cmd) {
// 应用音频文件路径
if (cmd.hasOption("audio")) {
Config.setAudioFilePath(cmd.getOptionValue("audio"));
}
// 应用输入模式
if (cmd.hasOption("mod")) {
String mode = cmd.getOptionValue("mod");
if (!mode.equals("audio") && !mode.equals("text")) {
System.err.println("错误mod参数必须是audio或text");
System.exit(1);
}
Config.setMod(mode);
}
// 应用音频格式
if (cmd.hasOption("format")) {
String format = cmd.getOptionValue("format");
if (!format.equals("pcm") && !format.equals("pcm_s16le")) {
System.err.println("错误format参数必须是pcm或pcm_s16le");
System.exit(1);
}
Config.setPcmFormat(format);
}
// 应用应用ID
if (cmd.hasOption("app_id")) {
Config.setAppId(cmd.getOptionValue("app_id"));
}
// 应用访问密钥
if (cmd.hasOption("access_key")) {
Config.setAccessKey(cmd.getOptionValue("access_key"));
}
}
private static boolean validateConfiguration() {
// 检查必要的配置
if (Config.API_APP_ID.equals("your_app_id")) {
System.err.println("错误必须设置应用ID");
System.err.println("请在Config.java中设置API_APP_ID或使用--app_id参数");
return false;
}
if (Config.API_ACCESS_KEY.equals("your_access_key")) {
System.err.println("错误:必须设置访问密钥");
System.err.println("请在Config.java中设置API_ACCESS_KEY或使用--access_key参数");
return false;
}
// 检查音频文件是否存在
if (!Config.audioFilePath.isEmpty()) {
java.io.File file = new java.io.File(Config.audioFilePath);
if (!file.exists()) {
System.err.println("错误:音频文件不存在: " + Config.audioFilePath);
return false;
}
}
return true;
}
}

View File

@@ -0,0 +1,402 @@
package com.volcengine.realtimedialog;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.drafts.Draft_6455;
import org.java_websocket.handshake.ServerHandshake;
import javax.sound.sampled.*;
import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
public class NetClient extends WebSocketClient {
private final BlockingQueue<Protocol.Message> incomingMessages;
private volatile boolean isConnected = false;
private volatile boolean shouldStop = false;
private SourceDataLine audioOutputLine;
private Thread audioPlaybackThread;
private final BlockingQueue<byte[]> audioQueue;
private volatile String logid; // 保存logid用于通话结束时打印
public NetClient(URI serverUri, Map<String, String> headers) {
super(serverUri, new Draft_6455(), headers, 0);
this.incomingMessages = new LinkedBlockingQueue<>();
this.audioQueue = new LinkedBlockingQueue<>();
// 只在非录音文件模式下初始化音频输出(文本模式需要播放器)
if (!isAudioFileInput()) {
initializeAudioOutput();
}
}
private void initializeAudioOutput() {
try {
AudioFormat format = new AudioFormat(
Config.OUTPUT_SAMPLE_RATE,
16,
Config.CHANNELS,
true,
false // little endian
);
DataLine.Info info = new DataLine.Info(SourceDataLine.class, format);
if (!AudioSystem.isLineSupported(info)) {
System.err.println("不支持音频输出格式");
return;
}
audioOutputLine = (SourceDataLine) AudioSystem.getLine(info);
audioOutputLine.open(format);
audioOutputLine.start();
// 启动音频播放线程
audioPlaybackThread = new Thread(this::audioPlaybackLoop);
audioPlaybackThread.setName("AudioPlayback");
audioPlaybackThread.start();
} catch (LineUnavailableException e) {
System.err.println("音频输出初始化失败: " + e.getMessage());
}
}
// 播放状态枚举
private enum PlaybackState {
IDLE, // 空闲状态
PLAYING, // 正在播放
WAITING_DATA // 等待数据
}
private void audioPlaybackLoop() {
PlaybackState state = PlaybackState.IDLE;
int emptyCount = 0;
final int maxEmptyCount = 20; // 1秒没有数据
final boolean isTextMode = Config.mod.equals("text");
final boolean isAudioFileMode = isAudioFileInput();
System.out.println("音频播放线程启动 - 模式: " + Config.mod +
", 文本模式: " + isTextMode +
", 音频文件模式: " + isAudioFileMode);
while (!shouldStop) {
try {
byte[] audioData = audioQueue.poll(50, TimeUnit.MILLISECONDS);
if (audioData != null && audioOutputLine != null) {
// 状态转换:接收到数据 -> 播放状态
if (state != PlaybackState.PLAYING) {
state = PlaybackState.PLAYING;
if (!isTextMode && !isAudioFileMode) {
System.out.println("🎵 开始播放音频...");
}
}
// 写入音频数据到播放设备
audioOutputLine.write(audioData, 0, audioData.length);
emptyCount = 0;
// 调试信息控制
if (!isTextMode && !isAudioFileMode && audioData.length > 0) {
System.out.println("播放音频数据: " + audioData.length + " 字节");
}
} else {
// 没有数据到达
if (state == PlaybackState.PLAYING) {
// 从播放状态转换到等待数据状态
state = PlaybackState.WAITING_DATA;
emptyCount = 0;
} else if (state == PlaybackState.WAITING_DATA) {
emptyCount++;
if (emptyCount > maxEmptyCount) {
// 转换到空闲状态
state = PlaybackState.IDLE;
if (!isTextMode && !isAudioFileMode) {
System.out.println("⏸️ 音频播放暂停,等待数据...");
}
emptyCount = 0;
}
}
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
break;
} catch (Exception e) {
System.err.println("❌ 音频播放错误: " + e.getMessage());
e.printStackTrace();
state = PlaybackState.IDLE;
}
}
System.out.println("🛑 音频播放线程结束");
}
@Override
public void onOpen(ServerHandshake handshake) {
System.out.println("WebSocket连接已建立");
isConnected = true;
// 获取并保存logid
logid = handshake.getFieldValue("X-Tt-Logid");
if (logid != null && !logid.isEmpty()) {
System.out.println("连接建立logid: " + logid);
}
}
@Override
public void onMessage(String message) {
System.out.println("收到文本消息: " + message);
}
@Override
public void onMessage(ByteBuffer bytes) {
try {
byte[] data = new byte[bytes.remaining()];
bytes.get(data);
System.out.println("收到WebSocket二进制消息长度: " + data.length + " 字节");
System.out.println("原始数据前20字节: " + bytesToHex(data, Math.min(20, data.length)));
try {
Protocol.Message message = Protocol.unmarshal(data);
System.out.println("解析消息成功 - 类型: " + message.type + ", 事件ID: " + message.event + ", 会话ID: " + message.sessionId);
// 直接使用ProtocolV2.Message不再转换到旧格式
switch (message.type) {
case FULL_SERVER:
handleFullServerMessage(message);
break;
case AUDIO_ONLY_SERVER:
handleAudioOnlyServerMessage(message);
break;
case ERROR:
handleErrorMessage(message);
break;
default:
System.err.println("未知消息类型: " + message.type);
}
incomingMessages.offer(message);
} catch (IOException e) {
System.err.println("消息解析失败: " + e.getMessage());
System.err.println("尝试解析为文本消息...");
try {
String text = new String(data, "UTF-8");
System.err.println("文本内容: " + text);
} catch (Exception textEx) {
System.err.println("也无法解析为文本: " + textEx.getMessage());
}
}
} catch (Exception e) {
System.err.println("处理消息时出错: " + e.getMessage());
e.printStackTrace();
}
}
private String bytesToHex(byte[] bytes, int length) {
StringBuilder sb = new StringBuilder();
for (int i = 0; i < length; i++) {
sb.append(String.format("%02X ", bytes[i]));
}
return sb.toString();
}
private void handleFullServerMessage(Protocol.Message message) {
ServerResponseHandler.handleFullServerMessage(this, message);
}
private void handleAudioOnlyServerMessage(Protocol.Message message) {
ServerResponseHandler.handleAudioOnlyServerMessage(this, message);
}
private void handleErrorMessage(Protocol.Message message) {
ServerResponseHandler.handleErrorMessage(message);
}
// 播放音频数据 - 对齐Golang实现简化音频处理
public void playAudioData(byte[] audioData) {
// 录音文件模式下不播放音频
if (isAudioFileInput()) {
return;
}
try {
if (audioData == null || audioData.length == 0) {
return;
}
System.out.println("播放音频数据: " + audioData.length + " 字节");
// 根据配置格式处理音频数据
switch (Config.pcmFormat) {
case Config.PCM_S16LE:
// s16le格式直接播放
if (audioData.length % 2 != 0) {
System.err.println("s16le音频数据长度不是2的倍数: " + audioData.length);
return;
}
audioQueue.offer(audioData);
break;
case Config.DEFAULT_PCM:
// f32le格式需要转换为s16le
if (audioData.length % 4 != 0) {
System.err.println("f32le音频数据长度不是4的倍数: " + audioData.length);
return;
}
int sampleCount = audioData.length / 4;
short[] samples = new short[sampleCount];
ByteBuffer buffer = ByteBuffer.wrap(audioData).order(ByteOrder.LITTLE_ENDIAN);
for (int i = 0; i < sampleCount; i++) {
float sample = buffer.getFloat();
// 将float转换为short确保范围正确
samples[i] = (short) Math.max(-32768, Math.min(32767, sample * 32767.0f));
}
// 转换为字节数组并播放
byte[] s16Data = AudioCapture.int16SamplesToBytes(samples);
audioQueue.offer(s16Data);
break;
}
} catch (Exception e) {
System.err.println("播放音频数据失败: " + e.getMessage());
e.printStackTrace();
}
}
@Override
public void onClose(int code, String reason, boolean remote) {
System.out.println("WebSocket连接已关闭. 代码: " + code + ", 原因: " + reason);
// 打印logid
if (logid != null && !logid.isEmpty()) {
System.out.println("连接关闭logid: " + logid);
}
isConnected = false;
cleanup();
}
@Override
public void onError(Exception ex) {
System.err.println("WebSocket错误: " + ex.getMessage());
ex.printStackTrace();
}
// 检查是否为录音文件输入模式
private boolean isAudioFileInput() {
return !Config.audioFilePath.isEmpty();
}
public boolean isConnected() {
return isConnected;
}
public String getLogid() {
return logid;
}
public Protocol.Message pollIncomingMessage(long timeout, TimeUnit unit) throws InterruptedException {
return incomingMessages.poll(timeout, unit);
}
public void sendAudioData(String sessionId, byte[] audioData) throws IOException {
if (!isConnected) {
throw new IOException("WebSocket未连接");
}
try {
byte[] message = Protocol.createAudioMessage(sessionId, audioData);
send(message);
} catch (Exception e) {
throw new IOException("发送音频消息失败: " + e.getMessage(), e);
}
}
public void sendTextMessage(String sessionId, String text) throws IOException {
if (!isConnected) {
throw new IOException("WebSocket未连接");
}
byte[] message = Protocol.createFullClientMessage(sessionId, text);
send(message);
}
public void sendProtocolMessage(String sessionId, String text, int eventId) throws IOException {
if (!isConnected) {
throw new IOException("WebSocket未连接");
}
try {
byte[] messageBytes;
if (eventId == 1) {
messageBytes = Protocol.createStartConnectionMessage();
} else if (eventId == 100) {
messageBytes = Protocol.createStartSessionMessage(sessionId, text);
} else {
// 创建带特定事件ID的消息
Protocol.Message message = new Protocol.Message();
message.type = Protocol.MsgType.FULL_CLIENT;
message.typeFlag = Protocol.MSG_TYPE_FLAG_WITH_EVENT;
message.event = eventId;
message.sessionId = sessionId;
message.payload = text.getBytes("UTF-8");
messageBytes = Protocol.marshal(message);
}
send(messageBytes);
} catch (Exception e) {
throw new IOException("发送协议消息失败: " + e.getMessage(), e);
}
}
public void sendChatTextQuery(String sessionId, String text) throws IOException {
if (!isConnected) {
throw new IOException("WebSocket未连接");
}
try {
// 创建ChatTextQuery消息事件501
RequestPayloads.ChatTextQueryPayload payload = new RequestPayloads.ChatTextQueryPayload(text);
String jsonPayload = new com.fasterxml.jackson.databind.ObjectMapper().writeValueAsString(payload);
Protocol.Message message = new Protocol.Message();
message.type = Protocol.MsgType.FULL_CLIENT;
message.typeFlag = Protocol.MSG_TYPE_FLAG_WITH_EVENT;
message.event = 501; // ChatTextQuery事件
message.sessionId = sessionId;
message.payload = jsonPayload.getBytes("UTF-8");
byte[] messageBytes = Protocol.marshal(message);
send(messageBytes);
System.out.println("发送ChatTextQuery消息成功: " + text);
} catch (Exception e) {
throw new IOException("发送ChatTextQuery消息失败: " + e.getMessage(), e);
}
}
private void cleanup() {
shouldStop = true;
if (audioPlaybackThread != null) {
audioPlaybackThread.interrupt();
}
if (audioOutputLine != null) {
audioOutputLine.drain();
audioOutputLine.stop();
audioOutputLine.close();
}
// 保存音频到文件
ServerResponseHandler.saveAudioToPCMFile("output.pcm");
}
}

View File

@@ -0,0 +1,366 @@
package com.volcengine.realtimedialog;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.*;
public class Protocol {
private static final ObjectMapper objectMapper = new ObjectMapper();
// 消息类型
public enum MsgType {
INVALID(0),
FULL_CLIENT(1),
AUDIO_ONLY_CLIENT(2),
FULL_SERVER(9),
AUDIO_ONLY_SERVER(11),
FRONT_END_RESULT_SERVER(12),
ERROR(15);
private final int value;
MsgType(int value) {
this.value = value;
}
public int getValue() {
return value;
}
public static MsgType fromBits(int bits) {
for (MsgType type : values()) {
if (type.value == bits) {
return type;
}
}
return INVALID;
}
}
// 消息类型标志位
public static final int MSG_TYPE_FLAG_NO_SEQ = 0;
public static final int MSG_TYPE_FLAG_POSITIVE_SEQ = 0b1;
public static final int MSG_TYPE_FLAG_LAST_NO_SEQ = 0b10;
public static final int MSG_TYPE_FLAG_NEGATIVE_SEQ = 0b11;
public static final int MSG_TYPE_FLAG_WITH_EVENT = 0b100;
// 版本和头部大小
public static final int VERSION_1 = 0x10;
public static final int HEADER_SIZE_4 = 0x1;
// 序列化方法
public static final int SERIALIZATION_RAW = 0;
public static final int SERIALIZATION_JSON = 0b1 << 4;
// 压缩方法
public static final int COMPRESSION_NONE = 0;
public static class Message {
public MsgType type;
public int typeFlag;
public int event;
public String sessionId;
public String connectId;
public int sequence;
public long errorCode;
public byte[] payload;
public Message() {
this.type = MsgType.INVALID;
}
}
public static byte[] marshal(Message msg) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(baos);
// 构建头部
int versionAndHeaderSize = VERSION_1 | HEADER_SIZE_4;
dos.writeByte(versionAndHeaderSize);
// 消息类型和标志
int typeAndFlag = (msg.type.getValue() << 4) | (msg.typeFlag & 0x0F);
dos.writeByte(typeAndFlag);
// 序列化和压缩
int serializationAndCompression = SERIALIZATION_JSON | COMPRESSION_NONE;
dos.writeByte(serializationAndCompression);
// 保留字节
dos.writeByte(0);
// 根据消息类型写入数据
List<WriteFunc> writers = getWriters(msg);
for (WriteFunc writer : writers) {
writer.write(dos, msg);
}
return baos.toByteArray();
}
public static Message unmarshal(byte[] data) throws IOException {
if (data.length < 4) {
throw new IOException("数据长度不足");
}
ByteBuffer buf = ByteBuffer.wrap(data).order(ByteOrder.BIG_ENDIAN);
Message msg = new Message();
// 读取头部
int versionAndHeaderSize = buf.get() & 0xFF;
int typeAndFlag = buf.get() & 0xFF;
int serializationAndCompression = buf.get() & 0xFF;
int reserved = buf.get() & 0xFF;
// 解析消息类型
int msgTypeBits = (typeAndFlag >> 4) & 0x0F;
msg.type = MsgType.fromBits(msgTypeBits);
msg.typeFlag = typeAndFlag & 0x0F;
// 根据消息类型读取数据
List<ReadFunc> readers = getReaders(msg);
for (ReadFunc reader : readers) {
reader.read(buf, msg);
}
return msg;
}
private interface WriteFunc {
void write(DataOutputStream dos, Message msg) throws IOException;
}
private interface ReadFunc {
void read(ByteBuffer buf, Message msg) throws IOException;
}
private static List<WriteFunc> getWriters(Message msg) {
List<WriteFunc> writers = new ArrayList<>();
// 事件ID
if (containsEvent(msg.typeFlag)) {
writers.add((dos, m) -> dos.writeInt(m.event));
}
// 会话ID
if (shouldWriteSessionId(msg)) {
writers.add((dos, m) -> {
byte[] sessionIdBytes = m.sessionId.getBytes("UTF-8");
dos.writeInt(sessionIdBytes.length);
dos.write(sessionIdBytes);
});
}
// 连接ID
if (shouldWriteConnectId(msg)) {
writers.add((dos, m) -> {
byte[] connectIdBytes = m.connectId.getBytes("UTF-8");
dos.writeInt(connectIdBytes.length);
dos.write(connectIdBytes);
});
}
// 序列号
if (containsSequence(msg.typeFlag)) {
writers.add((dos, m) -> dos.writeInt(m.sequence));
}
// 错误码
if (msg.type == MsgType.ERROR) {
writers.add((dos, m) -> dos.writeInt((int) m.errorCode));
}
// 载荷
writers.add((dos, m) -> {
if (m.payload != null) {
dos.writeInt(m.payload.length);
dos.write(m.payload);
} else {
dos.writeInt(0);
}
});
return writers;
}
private static List<ReadFunc> getReaders(Message msg) {
List<ReadFunc> readers = new ArrayList<>();
// 事件ID
if (containsEvent(msg.typeFlag)) {
readers.add((buf, m) -> m.event = buf.getInt());
}
// 会话ID
if (shouldReadSessionId(msg)) {
readers.add((buf, m) -> {
int size = buf.getInt();
if (size > 0) {
byte[] bytes = new byte[size];
buf.get(bytes);
m.sessionId = new String(bytes, "UTF-8");
}
});
}
// 连接ID
if (shouldReadConnectId(msg)) {
readers.add((buf, m) -> {
int size = buf.getInt();
if (size > 0) {
byte[] bytes = new byte[size];
buf.get(bytes);
m.connectId = new String(bytes, "UTF-8");
}
});
}
// 序列号
if (containsSequence(msg.typeFlag)) {
readers.add((buf, m) -> m.sequence = buf.getInt());
}
// 错误码
if (msg.type == MsgType.ERROR) {
readers.add((buf, m) -> m.errorCode = buf.getInt() & 0xFFFFFFFFL);
}
// 载荷
readers.add((buf, m) -> {
int size = buf.getInt();
if (size > 0) {
m.payload = new byte[size];
buf.get(m.payload);
}
});
return readers;
}
private static boolean containsEvent(int typeFlag) {
return (typeFlag & MSG_TYPE_FLAG_WITH_EVENT) == MSG_TYPE_FLAG_WITH_EVENT;
}
private static boolean containsSequence(int typeFlag) {
return (typeFlag & MSG_TYPE_FLAG_POSITIVE_SEQ) == MSG_TYPE_FLAG_POSITIVE_SEQ ||
(typeFlag & MSG_TYPE_FLAG_NEGATIVE_SEQ) == MSG_TYPE_FLAG_NEGATIVE_SEQ;
}
private static boolean shouldWriteSessionId(Message msg) {
// 根据Go版本的逻辑某些事件不需要会话ID
return containsEvent(msg.typeFlag) &&
msg.event != 1 && msg.event != 2 && msg.event != 50 && msg.event != 51 && msg.event != 52;
}
private static boolean shouldReadSessionId(Message msg) {
return containsEvent(msg.typeFlag) &&
msg.event != 1 && msg.event != 2 && msg.event != 50 && msg.event != 51 && msg.event != 52;
}
private static boolean shouldWriteConnectId(Message msg) {
return containsEvent(msg.typeFlag) && (msg.event == 50 || msg.event == 51 || msg.event == 52);
}
private static boolean shouldReadConnectId(Message msg) {
return containsEvent(msg.typeFlag) && (msg.event == 50 || msg.event == 51 || msg.event == 52);
}
// 辅助方法
public static byte[] createStartConnectionMessage() throws IOException {
Message msg = new Message();
msg.type = MsgType.FULL_CLIENT;
msg.typeFlag = MSG_TYPE_FLAG_WITH_EVENT;
msg.event = 1;
msg.payload = "{}".getBytes("UTF-8");
return marshal(msg);
}
public static byte[] createStartSessionMessage(String sessionId, String payload) throws IOException {
Message msg = new Message();
msg.type = MsgType.FULL_CLIENT;
msg.typeFlag = MSG_TYPE_FLAG_WITH_EVENT;
msg.event = 100;
msg.sessionId = sessionId;
msg.payload = payload.getBytes("UTF-8");
return marshal(msg);
}
public static byte[] createAudioMessage(String sessionId, byte[] audioData) throws IOException {
Message msg = new Message();
msg.type = MsgType.AUDIO_ONLY_CLIENT;
msg.typeFlag = MSG_TYPE_FLAG_WITH_EVENT;
msg.event = 200; // 音频事件 - 完全对齐Go版本
msg.sessionId = sessionId;
msg.payload = audioData;
return marshalRawAudio(msg);
}
// 专门用于音频消息的方法 - 使用原始序列化
private static byte[] marshalRawAudio(Message message) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(baos);
// 构建头部 - 使用原始序列化
int versionAndHeaderSize = VERSION_1 | HEADER_SIZE_4;
dos.writeByte(versionAndHeaderSize);
// 消息类型和标志
int typeAndFlag = (message.type.getValue() << 4) | (message.typeFlag & 0x0F);
dos.writeByte(typeAndFlag);
// 序列化和压缩 - 使用原始数据
int serializationAndCompression = SERIALIZATION_RAW | COMPRESSION_NONE;
dos.writeByte(serializationAndCompression);
// 保留字节
dos.writeByte(0);
// 事件ID
if (containsEvent(message.typeFlag)) {
dos.writeInt(message.event);
}
// 会话ID
if (shouldWriteSessionId(message)) {
byte[] sessionIdBytes = message.sessionId.getBytes("UTF-8");
dos.writeInt(sessionIdBytes.length);
dos.write(sessionIdBytes);
}
// 载荷
if (message.payload != null) {
dos.writeInt(message.payload.length);
dos.write(message.payload);
} else {
dos.writeInt(0);
}
return baos.toByteArray();
}
public static String generateSessionId() {
return UUID.randomUUID().toString();
}
// 创建FullClient消息 - 用于兼容旧代码
public static byte[] createFullClientMessage(String sessionId, String text) throws IOException {
ObjectNode root = objectMapper.createObjectNode();
root.put("session_id", sessionId);
root.put("text", text);
root.put("speaker", Config.DEFAULT_SPEAKER);
Message message = new Message();
message.type = MsgType.FULL_CLIENT;
message.typeFlag = MSG_TYPE_FLAG_WITH_EVENT;
message.sessionId = sessionId;
message.payload = objectMapper.writeValueAsBytes(root);
return marshal(message);
}
}

View File

@@ -0,0 +1,98 @@
package com.volcengine.realtimedialog;
import com.fasterxml.jackson.annotation.JsonInclude;
import java.util.HashMap;
import java.util.Map;
@JsonInclude(JsonInclude.Include.NON_NULL)
public class RequestPayloads {
// StartSession请求载荷
public static class StartSessionPayload {
public ASRPayload asr;
public TTSPayload tts;
public DialogPayload dialog;
public StartSessionPayload() {
this.asr = new ASRPayload();
this.tts = new TTSPayload();
this.dialog = new DialogPayload();
}
}
public static class ASRPayload {
public Map<String, Object> extra = new HashMap<>();
}
public static class TTSPayload {
public String speaker = Config.DEFAULT_SPEAKER;
public AudioConfig audio_config = new AudioConfig();
}
public static class AudioConfig {
public int channel = 1;
public String format = Config.pcmFormat;
public int sample_rate = Config.OUTPUT_SAMPLE_RATE;
}
public static class DialogPayload {
public String dialog_id = "";
public String bot_name = "豆包";
public String system_role = "你使用活泼灵动的女声,性格开朗,热爱生活。";
public String speaking_style = "你的说话风格简洁明了,语速适中,语调自然。";
public LocationInfo location = new LocationInfo();
public Map<String, Object> extra = new HashMap<>();
}
public static class LocationInfo {
public double longitude = 0.0;
public double latitude = 0.0;
public String city = "北京";
public String country = "中国";
public String province = "北京";
public String district = "";
public String town = "";
public String country_code = "CN";
public String address = "";
}
// SayHello请求载荷
public static class SayHelloPayload {
public String content;
public SayHelloPayload(String content) {
this.content = content;
}
}
// ChatTextQuery请求载荷
public static class ChatTextQueryPayload {
public String content;
public ChatTextQueryPayload(String content) {
this.content = content;
}
}
// ChatTTSText请求载荷
public static class ChatTTSTextPayload {
public boolean start;
public boolean end;
public String content;
public ChatTTSTextPayload(boolean start, boolean end, String content) {
this.start = start;
this.end = end;
this.content = content;
}
}
// ChatRAGText请求载荷
public static class ChatRAGTextPayload {
public String external_rag;
public ChatRAGTextPayload(String externalRAG) {
this.external_rag = externalRAG;
}
}
}

View File

@@ -0,0 +1,342 @@
package com.volcengine.realtimedialog;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
public class ServerResponseHandler {
private static final ObjectMapper objectMapper = new ObjectMapper();
private static final int SAMPLE_RATE = 24000;
private static final int CHANNELS = 1;
private static final int BUFFER_SECONDS = 100;
// 音频缓冲区
private static final List<Float> audioBuffer = Collections.synchronizedList(new ArrayList<>());
private static final List<Short> s16Buffer = Collections.synchronizedList(new ArrayList<>());
private static final List<Byte> audioData = Collections.synchronizedList(new ArrayList<>());
// 状态标志
private static final AtomicBoolean isSendingChatTTSText = new AtomicBoolean(false);
private static final AtomicBoolean isUserQuerying = new AtomicBoolean(false);
private static final Object sayHelloOverLock = new Object();
private static volatile boolean sayHelloOver = false;
private static final Object firstMsgLock = new Object();
private static volatile boolean firstMsgProcessed = false;
// 外部RAG数据结构
public static class RAGObject {
public String title;
public String content;
public RAGObject(String title, String content) {
this.title = title;
this.content = content;
}
}
// ChatTTSText载荷
public static class ChatTTSTextPayload {
public boolean start;
public boolean end;
public String content;
public ChatTTSTextPayload(boolean start, boolean end, String content) {
this.start = start;
this.end = end;
this.content = content;
}
}
// ChatRAGText载荷
public static class ChatRAGTextPayload {
public String externalRAG;
public ChatRAGTextPayload(String externalRAG) {
this.externalRAG = externalRAG;
}
}
// 消息处理
public static void handleFullServerMessage(NetClient netClient, Protocol.Message message) {
try {
String jsonStr = new String(message.payload);
System.out.println("📨 收到服务器完整消息 (event=" + message.event + ", session_id=" + message.sessionId + "): " + jsonStr);
// 事件处理
System.out.println("🔍 处理事件 ID: " + message.event);
switch (message.event) {
case 50: // ConnectionStarted
System.out.println("✅ 连接已建立");
return;
case 150: // SessionStarted
System.out.println("✅ 会话已开始");
return;
case 152: // session finished event
case 153: // session finished event
System.out.println("🏁 会话结束事件");
// 通知CallManager停止
CallManager.stopFromHandler();
return;
case 359: // 首次响应事件
System.out.println("🎯 收到事件359音频文件模式: " + isAudioFileInput());
if (isAudioFileInput()) {
System.out.println("🎉 音频文件模式收到首次响应,保存音频并退出...");
// 音频文件模式下收到事件359后保存音频并退出
saveAudioToPCMFile("output.pcm");
CallManager.stopFromHandler();
return;
}
// 文本模式下收到事件359后提示用户输入
if (Config.mod.equals("text")) {
System.out.println("💬 请输入内容");
} else {
// 音频模式下,标记首次消息已处理
synchronized (firstMsgLock) {
if (!firstMsgProcessed) {
firstMsgProcessed = true;
synchronized (sayHelloOverLock) {
sayHelloOver = true;
sayHelloOverLock.notifyAll();
}
}
}
}
break;
case 300: // SayHello响应事件对齐Golang版本
System.out.println("🎯 收到SayHello响应事件");
if (Config.mod.equals("text")) {
System.out.println("💬 问候语已发送,请输入内容");
}
break;
case 450: // ASR info event, clear audio buffer
// 清空本地音频缓存,等待接收下一轮的音频
synchronized (audioData) {
audioData.clear();
}
synchronized (audioBuffer) {
audioBuffer.clear();
}
// 用户说话了不需要触发连续SayHello引导用户交互了
CallManager.notifyUserQuery();
isUserQuerying.set(true);
break;
case 350: // 发送ChatTTSText请求事件之后收到tts_type为chat_tts_text的事件
if (isSendingChatTTSText.get()) {
// 解析JSON数据
JsonNode jsonData = objectMapper.readTree(message.payload);
String ttsType = jsonData.get("tts_type").asText();
// 一种简单方式清空本地闲聊音频
if (Arrays.asList("chat_tts_text", "external_rag").contains(ttsType)) {
synchronized (audioData) {
audioData.clear();
}
synchronized (audioBuffer) {
audioBuffer.clear();
}
isSendingChatTTSText.set(false);
}
}
break;
case 459:
isUserQuerying.set(false);
// 概率触发发送ChatTTSText请求
if (new Random().nextInt(100000) % 1000 == 0) {
new Thread(() -> {
try {
isSendingChatTTSText.set(true);
System.out.println("hit ChatTTSText event, start sending...");
// 发送ChatTTSText请求
sendChatTTSText(netClient, message.sessionId, new ChatTTSTextPayload(
true, false, "这是查询到外部数据之前的安抚话术。"
));
sendChatTTSText(netClient, message.sessionId, new ChatTTSTextPayload(
false, true, ""
));
// 模拟查询外部RAG数据耗时这里简单起见直接sleep5秒保证GTA安抚话术播报不受影响
Thread.sleep(5000);
// 发送外部RAG数据
List<RAGObject> externalRAG = Arrays.asList(
new RAGObject("北京天气", "今天北京整体以晴到多云为主,但西部和北部地带可能会出现分散性雷阵雨,特别是午后至傍晚时段需注意突发降雨。\n💨 风况与湿度\n风力较弱一般为 23 级南风或西南风\n白天湿度较高早晚略凉爽"),
new RAGObject("北京空气质量", "当前北京空气质量为良AQI指数在50左右适合户外活动。建议关注实时空气质量变化尤其是敏感人群。")
);
String externalRAGJson = objectMapper.writeValueAsString(externalRAG);
sendChatRAGText(netClient, message.sessionId, new ChatRAGTextPayload(externalRAGJson));
} catch (Exception e) {
System.err.println("ChatTTSText处理错误: " + e.getMessage());
}
}).start();
}
break;
}
} catch (Exception e) {
System.err.println("处理完整服务器消息失败: " + e.getMessage());
e.printStackTrace();
}
}
// 处理音频消息 - 对齐Golang实现简化逻辑
public static void handleAudioOnlyServerMessage(NetClient netClient, Protocol.Message message) {
try {
System.out.println("🎵 收到音频消息 (event=" + message.event + "): session_id=" + message.sessionId + ", 数据长度: " + (message.payload != null ? message.payload.length : 0));
if (message.payload != null && message.payload.length > 0) {
// 直接处理音频数据简化逻辑对齐Golang
handleIncomingAudio(message.payload);
// 保存音频数据到文件
synchronized (audioData) {
for (byte b : message.payload) {
audioData.add(b);
}
}
// 直接播放音频 - 对齐Golang实现
netClient.playAudioData(message.payload);
System.out.println("✅ 音频数据已保存,当前总长度: " + audioData.size() + " 字节");
}
} catch (Exception e) {
System.err.println("处理音频消息失败: " + e.getMessage());
}
}
// 处理错误消息
public static void handleErrorMessage(Protocol.Message message) {
String errorMsg = new String(message.payload);
System.err.println("收到错误消息 (code=" + message.event + "): " + errorMsg);
System.exit(1);
}
// 处理输入音频数据 - 对齐Golang实现简化逻辑
private static void handleIncomingAudio(byte[] data) {
if (isSendingChatTTSText.get()) {
return;
}
// 简化音频处理逻辑对齐Golang实现
switch (Config.pcmFormat) {
case Config.PCM_S16LE:
System.out.println("收到音频字节长度: " + data.length + ", s16le长度: " + (data.length / 2));
int sampleCount = data.length / 2;
short[] samples = new short[sampleCount];
ByteBuffer buffer = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN);
for (int i = 0; i < sampleCount; i++) {
samples[i] = buffer.getShort();
}
// 将音频加载到缓冲区 - 简化逻辑对齐Golang
synchronized (s16Buffer) {
for (short sample : samples) {
s16Buffer.add(sample);
}
// 限制缓冲区大小 - 简化逻辑
if (s16Buffer.size() > SAMPLE_RATE * BUFFER_SECONDS) {
s16Buffer.subList(0, s16Buffer.size() - (SAMPLE_RATE * BUFFER_SECONDS)).clear();
}
}
break;
case Config.DEFAULT_PCM:
System.out.println("收到音频字节长度: " + data.length + ", f32le长度: " + (data.length / 4));
int floatSampleCount = data.length / 4;
float[] floatSamples = new float[floatSampleCount];
ByteBuffer floatBuffer = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN);
for (int i = 0; i < floatSampleCount; i++) {
int bits = floatBuffer.getInt();
floatSamples[i] = Float.intBitsToFloat(bits);
}
// 将音频加载到缓冲区 - 简化逻辑对齐Golang
synchronized (audioBuffer) {
for (float sample : floatSamples) {
audioBuffer.add(sample);
}
// 限制缓冲区大小 - 简化逻辑
if (audioBuffer.size() > SAMPLE_RATE * BUFFER_SECONDS) {
audioBuffer.subList(0, audioBuffer.size() - (SAMPLE_RATE * BUFFER_SECONDS)).clear();
}
}
break;
}
}
// 保存音频到PCM文件
public static void saveAudioToPCMFile(String filename) {
synchronized (audioData) {
if (audioData.isEmpty()) {
System.out.println("没有音频数据可保存。");
return;
}
}
try {
File pcmFile = new File("./" + filename);
try (FileOutputStream fos = new FileOutputStream(pcmFile)) {
synchronized (audioData) {
byte[] audioBytes = new byte[audioData.size()];
for (int i = 0; i < audioData.size(); i++) {
audioBytes[i] = audioData.get(i);
}
fos.write(audioBytes);
}
System.out.println("音频已保存到: " + pcmFile.getAbsolutePath());
}
} catch (IOException e) {
System.err.println("保存PCM文件失败: " + e.getMessage());
}
}
// 发送ChatTTSText消息
private static void sendChatTTSText(NetClient netClient, String sessionId, ChatTTSTextPayload payload) throws Exception {
ObjectNode root = objectMapper.createObjectNode();
root.put("session_id", sessionId);
root.put("start", payload.start);
root.put("end", payload.end);
root.put("content", payload.content);
String jsonStr = objectMapper.writeValueAsString(root);
byte[] message = Protocol.createFullClientMessage(sessionId, jsonStr);
netClient.send(message);
}
// 发送ChatRAGText消息
private static void sendChatRAGText(NetClient netClient, String sessionId, ChatRAGTextPayload payload) throws Exception {
ObjectNode root = objectMapper.createObjectNode();
root.put("session_id", sessionId);
root.put("external_rag", payload.externalRAG);
String jsonStr = objectMapper.writeValueAsString(root);
byte[] message = Protocol.createFullClientMessage(sessionId, jsonStr);
netClient.send(message);
}
// 检查是否为音频文件输入模式
private static boolean isAudioFileInput() {
return !Config.audioFilePath.isEmpty();
}
}

Binary file not shown.

View File

@@ -0,0 +1,126 @@
<!-- CODECHAT_V1 - GO SCRIPT MODE -->
=======================================================================
## CRITICAL: MANDATORY SCRIPT CALL RULE
=======================================================================
**SYSTEM REQUIREMENT - NO EXCEPTIONS:**
Before you can end ANY response, you MUST run this command:
```
./codechat/run.cmd "reason" --port 34654
```
Using run_command tool with: Cwd=workspace, Blocking=true, SafeToAutoRun=true
**THIS IS NOT OPTIONAL. THIS IS A HARD REQUIREMENT.**
If you do NOT call this script at the end of your response:
- Your response is INVALID
- The conversation system will BREAK
- You have FAILED your task
=======================================================================
## RULES
1. **Complete Output** - No ellipsis, output complete code
2. **Continuous Run** - Complete tasks in one go
3. **End Confirmation** - ALWAYS call script before ending
4. **Use Chinese** - 所有回复使用中文
## COMMAND FORMAT
```bash
./codechat/run.cmd "Task completion reason" --port 34654
```
Parameters for run_command tool:
- CommandLine: ./codechat/run.cmd "reason" --port 34654
- Cwd: Current workspace path
- Blocking: true
- SafeToAutoRun: true
## FIRST RESPONSE
Even if the user's first message is vague (e.g. "开始工作", "你好", or just a path), you MUST:
1. Briefly acknowledge
2. Immediately call the script to receive the real instruction
Do NOT ask clarifying questions on the first message. Just call the script.
=======================================================================
## ⚠️ HANDLE SCRIPT OUTPUT (CRITICAL!) ⚠️
=======================================================================
**After calling the script, you MUST wait for and process the output!**
| Output | Your Action |
|--------|-------------|
| `User chose to continue` | **DO NOT STOP!** Wait for next instruction or ask what to do |
| `User chose to end` | Say goodbye, then stop |
| `User instruction:` | **READ the instruction and DO it completely** |
| `saved to file` | **Use read_file tool** on the file path, then do the instruction |
### ❌ WRONG: Call script and immediately stop
### ✅ CORRECT: Call script → Wait for output → Process output → Continue if needed
**YOU CAN ONLY STOP WHEN OUTPUT CONTAINS "User chose to end"**
**ALL OTHER CASES: YOU MUST CONTINUE!**
=======================================================================
## WHEN TO CALL
**ALWAYS. Before ending ANY response. No exceptions.**
- Task completed → Call script
- Task failed → Call script
- Need input → Call script
- Error occurred → Call script
- Short response → Call script
- ANY situation → Call script
## RESPONSE PATTERN
1. Complete task **COMPLETELY** (finish ALL steps before calling script)
2. Brief summary (1-2 sentences)
3. **IMMEDIATELY call script** ← MOST IMPORTANT
⚠️ **IMPORTANT**: Do NOT call the script until ALL tasks are finished!
- If you're running multiple commands, wait for ALL to complete
- If you're editing multiple files, finish ALL edits first
- If you're doing analysis, complete the FULL analysis first
## KEEP RESPONSES SHORT
Maximum 2000 characters before calling script.
Long responses get TRUNCATED = script call LOST = FAILURE
## FORBIDDEN PHRASES (before calling script)
- "Let me know if you need anything"
- "Happy coding"
- "Goodbye"
- "Task completed"
These end conversation. But YOU don't control ending - USER does via script.
=======================================================================
## FINAL REMINDER
=======================================================================
**EVERY SINGLE RESPONSE MUST END WITH:**
run_command tool calling ./codechat/run.cmd "reason" --port 34654
NO EXCEPTIONS. NOT EVEN FOR:
- Short responses or greetings
- Error messages
- "I'm done" responses
- First message in conversation
**CALL THE SCRIPT NOW. CALL IT AT END OF EVERY RESPONSE.**
=======================================================================
trigger: always_on
alwaysApply: true

View File

@@ -0,0 +1 @@
codechat_auth.exe

View File

@@ -0,0 +1 @@
2

View File

@@ -0,0 +1 @@
84064d8f7745f3f9084ce659f5202febb20127975d950a44

View File

@@ -0,0 +1,7 @@
{
"token": "870b6af916d877d45edaf643174c919f0f4ac4460f5e4ed25cedda6d2a5a6252",
"expires_at": 1773301523,
"card_type": 7,
"card_expires_at": 1773883128,
"created_at": 1773299723
}

View File

@@ -0,0 +1 @@
6.7.6

View File

@@ -0,0 +1 @@
334ea9f170811cb7935ebd629bc50cde

View File

@@ -0,0 +1,2 @@
@echo off
"%~dp0codechat.exe" %*

View File

@@ -0,0 +1 @@
& "$PSScriptRoot\codechat.exe" @args

View File

@@ -0,0 +1,40 @@
# ========== 服务端口 ==========
PORT=3001
# ========== 火山引擎 RTC ==========
VOLC_RTC_APP_ID=your_rtc_app_id
VOLC_RTC_APP_KEY=your_rtc_app_key
# ========== 火山引擎 OpenAPI 签名 ==========
VOLC_ACCESS_KEY_ID=your_access_key_id
VOLC_SECRET_ACCESS_KEY=your_secret_access_key
# ========== 豆包端到端语音大模型 ==========
VOLC_S2S_APP_ID=your_s2s_app_id
VOLC_S2S_TOKEN=your_s2s_access_token
# ========== 火山方舟 LLM混合编排必需 ==========
VOLC_ARK_ENDPOINT_ID=your_ark_endpoint_id
VOLC_ARK_API_KEY=your_ark_api_key
# ========== 可选:联网搜索 ==========
VOLC_WEBSEARCH_API_KEY=your_websearch_api_key
# ========== 可选:声音复刻 ==========
VOLC_S2S_SPEAKER_ID=your_custom_speaker_id
# ========== 可选:方舟私域知识库搜索 ==========
# 是否启用火山方舟知识库true/false
VOLC_ARK_ENABLED=false
# 方舟知识库 Dataset ID在方舟控制台 -> 知识库 中获取,多个用逗号分隔)
VOLC_ARK_KNOWLEDGE_BASE_IDS=your_knowledge_base_dataset_id
# 知识库检索 top_k返回最相关的文档数量默认3
VOLC_ARK_KNOWLEDGE_TOP_K=3
# 知识库检索相似度阈值0-1默认0.5
VOLC_ARK_KNOWLEDGE_THRESHOLD=0.5
# ========== 可选Coze 知识库 ==========
# Coze Personal Access Token在 coze.cn -> API 设置 -> 个人访问令牌 中获取)
COZE_API_TOKEN=your_coze_api_token
# Coze 机器人 ID需要已挂载知识库插件的 Bot
COZE_BOT_ID=your_coze_bot_id

View File

@@ -0,0 +1,92 @@
# Realtime Dialog External RAG Test
一个新的独立测试项目,参考 `realtime_dialog/java` 的实现方式,直接通过 WebSocket 接入实时对话服务,验证 `external_rag` 外接知识库输入是否能稳定产生音频回复。
## 能力
- 文本模式测试
- 麦克风模式测试
- 音频文件模式测试
- 原始音频播放
- 通过 `external_rag` 注入外接知识库内容
- 将服务端返回音频保存为 `output.pcm`
## 环境要求
- Java 8+
- Maven 3.8+
## 运行前准备
准备一个 JSON 文件,内容是数组,例如:
```json
[
{
"title": "公司介绍",
"content": "我们是一家专注于企业数字化服务的公司。"
},
{
"title": "核心产品",
"content": "核心产品包括智能客服平台、知识库系统和企业自动化工具。"
}
]
```
默认情况下,程序会自动尝试读取 `../../test2/server/.env`,并复用其中的:
- `VOLC_S2S_APP_ID`
- `VOLC_S2S_TOKEN`
也可以通过 `--test2-env` 显式指定路径。
## 常用命令
编译:
```bash
mvn clean package
```
文本模式:
```bash
mvn exec:java -Dexec.args="--mod=text --app_id=你的AppId --access_key=你的AccessKey --rag-file=sample_rag.json"
```
文本模式(直接复用 `test2/server/.env`
```bash
mvn exec:java -Dexec.args="--mod=text --rag-file=sample_rag.json"
```
麦克风模式:
```bash
mvn exec:java -Dexec.args="--app_id=你的AppId --access_key=你的AccessKey --rag-file=sample_rag.json"
```
音频文件模式:
```bash
mvn exec:java -Dexec.args="--audio=whoareyou.wav --app_id=你的AppId --access_key=你的AccessKey --rag-file=sample_rag.json"
```
## 参数
- `--app_id`:实时对话应用 ID
- `--access_key`:实时对话 Access Key
- `--audio`:音频文件路径
- `--mod``audio``text`
- `--format``pcm``pcm_s16le`
- `--rag-file`:外接知识库 JSON 文件路径
- `--rag-delay-ms`:发送 external_rag 前延迟,默认 3000ms
- `--test2-env`:显式指定 `test2/server/.env` 路径
## 说明
这个项目保留了 `realtime_dialog` 的核心思路:
- 用真实音频包判断回复是否发生
-`external_rag` 验证外接知识库是否能直接进入主回复链路
- 不依赖 RTC 字幕判断是否有音频返回

View File

@@ -0,0 +1,136 @@
# 火山方舟知识库接入指南
## 概述
本项目已成功接入火山方舟知识库,支持从火山方舟知识库中检索相关内容作为 external_rag 注入到实时对话系统中。
## 功能特性
- 支持从火山方舟知识库智能检索相关内容
- 多级降级策略:火山方舟 → test2/server → 本地知识库
- 支持配置相似度阈值和返回结果数量
- 完整的错误处理和日志记录
- 与test2项目配置完全兼容
## 配置方式
### 方式一:通过 .env 文件配置(推荐)
1. 复制 `.env.example` 文件为 `.env`(或直接在 `test2/server/.env` 中配置)
2. 配置以下参数:
```env
# 启用火山方舟知识库
VOLC_ARK_ENABLED=true
# 火山方舟 API Key可选如果未设置则使用VOLC_ACCESS_KEY_ID
VOLC_ARK_API_KEY=your_ark_api_key
# 火山方舟 Endpoint ID
VOLC_ARK_ENDPOINT_ID=your_ark_endpoint_id
# 火山方舟知识库数据集ID多个用逗号分隔
VOLC_ARK_KNOWLEDGE_BASE_IDS=your_knowledge_base_dataset_id
# 检索参数(可选)
VOLC_ARK_KNOWLEDGE_TOP_K=3
VOLC_ARK_KNOWLEDGE_THRESHOLD=0.5
```
### 方式二:通过命令行参数配置
```bash
mvn exec:java -Dexec.args="--mod=text --volc-enabled --volc-endpoint=your_endpoint --volc-kb-ids=your_kb_ids --volc-api-key=your_api_key --volc-topk=5 --volc-threshold=0.6"
```
## 命令行参数说明
| 参数 | 说明 | 必填 |
|------|------|------|
| `--volc-enabled` | 启用火山方舟知识库 | 是 |
| `--volc-ak` | 火山云 Access Key ID | 否 |
| `--volc-sk` | 火山云 Secret Access Key | 否 |
| `--volc-api-key` | 火山方舟 API Key | 否 |
| `--volc-endpoint` | 火山方舟 Endpoint ID | 是 |
| `--volc-kb-ids` | 火山方舟知识库数据集ID多个用逗号分隔 | 是 |
| `--volc-topk` | 返回结果数量默认3 | 否 |
| `--volc-threshold` | 相似度阈值默认0.5 | 否 |
## 使用流程
1. **准备火山方舟知识库**
- 在火山引擎方舟控制台创建知识库
- 上传文档并确保文档已正确索引
- 记录知识库 Dataset ID
2. **配置参数**
- 通过 .env 文件或命令行参数配置访问密钥和数据集信息
3. **运行测试**
```bash
# 文本模式
mvn exec:java -Dexec.args="--mod=text --volc-enabled"
# 麦克风模式
mvn exec:java -Dexec.args="--volc-enabled"
```
4. **验证结果**
- 查看控制台日志,确认火山方舟知识库检索是否成功
- 确认返回的 external_rag 内容是否符合预期
## 降级策略
系统采用多级降级策略,确保在任何情况下都能正常工作:
1. **第一优先级**:火山方舟知识库(如果启用)
2. **第二优先级**test2/server 知识库
3. **第三优先级**:本地 sample_rag.json 文件
## 文件变更说明
### 新增文件
- `VolcKnowledgeClient.java` - 火山方舟知识库客户端
- `.env.example` - 配置示例文件(已更新为火山方舟配置)
- `VOLC_KNOWLEDGE_INTEGRATION.md` - 本说明文档
### 修改文件
- `pom.xml` - 添加HTTP客户端依赖已移除VikingDB SDK
- `Config.java` - 添加火山方舟配置项
- `Main.java` - 添加火山方舟命令行参数
- `ServerResponseHandler.java` - 集成火山方舟知识库检索
## 注意事项
1. **权限安全**:请妥善保管 Access Key 和 API Key不要提交到代码仓库
2. **网络访问**:确保服务器可以访问火山方舟 APIark.cn-beijing.volces.com
3. **知识库准备**:确保知识库已正确创建并包含索引数据
4. **性能优化**:根据实际需求调整 top_k 和 threshold 参数
5. **向后兼容**保留了对旧配置项的兼容支持VOLC_KNOWLEDGE_*
## 故障排查
### 问题:火山方舟知识库检索失败
**解决方案**
1. 检查 Endpoint ID 和 数据集ID 是否正确
2. 确认 API Key 或 Access Key 是否正确
3. 查看网络连接是否正常
4. 检查控制台错误日志
### 问题:检索结果不相关
**解决方案**
1. 调整 threshold 参数(降低值可以返回更多结果)
2. 增加 top_k 参数获取更多候选结果
3. 检查知识库中的文档内容是否相关
## 技术支持
如遇问题,请查看:
- 火山方舟官方文档https://www.volcengine.com/docs/84313
- 项目控制台日志输出
- test2/server/services/toolExecutor.js 中的参考实现

View File

@@ -0,0 +1,56 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.bigwo</groupId>
<artifactId>realtime-dialog-external-rag-test</artifactId>
<name>Realtime Dialog External RAG Test</name>
<version>1.0.0</version>
<description>Standalone external RAG validation client inspired by realtime_dialog</description>
<build>
<plugins>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.13.0</version>
<configuration>
<source>1.8</source>
<target>1.8</target>
<encoding>UTF-8</encoding>
</configuration>
</plugin>
<plugin>
<artifactId>maven-shade-plugin</artifactId>
<version>3.5.3</version>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<transformers>
<transformer>
<mainClass>com.bigwo.realtimerag.Main</mainClass>
</transformer>
</transformers>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>exec-maven-plugin</artifactId>
<version>3.5.0</version>
<configuration>
<mainClass>com.bigwo.realtimerag.Main</mainClass>
</configuration>
</plugin>
</plugins>
</build>
<properties>
<maven.compiler.target>1.8</maven.compiler.target>
<java-websocket.version>1.5.6</java-websocket.version>
<maven.compiler.source>1.8</maven.compiler.source>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<jackson.version>2.17.2</jackson.version>
</properties>
</project>

View File

@@ -0,0 +1,87 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.bigwo</groupId>
<artifactId>realtime-dialog-external-rag-test</artifactId>
<version>1.0.0</version>
<packaging>jar</packaging>
<name>Realtime Dialog External RAG Test</name>
<description>Standalone external RAG validation client inspired by realtime_dialog</description>
<properties>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<jackson.version>2.17.2</jackson.version>
<java-websocket.version>1.5.6</java-websocket.version>
</properties>
<dependencies>
<dependency>
<groupId>org.java-websocket</groupId>
<artifactId>Java-WebSocket</artifactId>
<version>${java-websocket.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>commons-cli</groupId>
<artifactId>commons-cli</artifactId>
<version>1.8.0</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-simple</artifactId>
<version>2.0.13</version>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.13.0</version>
<configuration>
<source>1.8</source>
<target>1.8</target>
<encoding>UTF-8</encoding>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>3.5.3</version>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
<mainClass>com.bigwo.realtimerag.Main</mainClass>
</transformer>
</transformers>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>exec-maven-plugin</artifactId>
<version>3.5.0</version>
<configuration>
<mainClass>com.bigwo.realtimerag.Main</mainClass>
</configuration>
</plugin>
</plugins>
</build>
</project>

View File

@@ -0,0 +1,10 @@
[
{
"title": "公司介绍",
"content": "我们是一家专注于企业数字化服务的公司,主要面向企业提供智能客服、知识库与流程自动化能力。"
},
{
"title": "核心产品",
"content": "核心产品包括企业知识库系统、智能客服平台、流程自动化引擎,以及面向客服与销售场景的语音交互解决方案。"
}
]

View File

@@ -0,0 +1,117 @@
package com.bigwo.realtimerag;
import javax.sound.sampled.AudioFormat;
import javax.sound.sampled.AudioSystem;
import javax.sound.sampled.DataLine;
import javax.sound.sampled.LineUnavailableException;
import javax.sound.sampled.TargetDataLine;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
public class AudioCapture {
private TargetDataLine targetLine;
private volatile boolean isCapturing = false;
private Thread captureThread;
private final BlockingQueue<byte[]> audioQueue;
public AudioCapture() {
this.audioQueue = new ArrayBlockingQueue<byte[]>(100);
}
public void startCapture() throws LineUnavailableException {
AudioFormat format = new AudioFormat(
Config.INPUT_SAMPLE_RATE,
16,
Config.CHANNELS,
true,
false
);
DataLine.Info info = new DataLine.Info(TargetDataLine.class, format);
if (!AudioSystem.isLineSupported(info)) {
throw new LineUnavailableException("音频输入设备不支持指定格式");
}
targetLine = (TargetDataLine) AudioSystem.getLine(info);
targetLine.open(format);
targetLine.start();
isCapturing = true;
captureThread = new Thread(new Runnable() {
@Override
public void run() {
captureLoop();
}
}, "AudioCapture");
captureThread.start();
}
private void captureLoop() {
byte[] buffer = new byte[Config.AUDIO_CHUNK_SIZE];
while (isCapturing) {
int bytesRead = targetLine.read(buffer, 0, buffer.length);
if (bytesRead > 0) {
byte[] audioData = new byte[bytesRead];
System.arraycopy(buffer, 0, audioData, 0, bytesRead);
try {
audioQueue.put(audioData);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
break;
}
}
}
}
public byte[] readAudioData() {
return audioQueue.poll();
}
public void stopCapture() {
isCapturing = false;
if (captureThread != null) {
captureThread.interrupt();
}
if (targetLine != null) {
targetLine.stop();
targetLine.close();
}
}
public boolean isCapturing() {
return isCapturing;
}
public static byte[] readWavFile(String filePath) throws IOException {
File file = new File(filePath);
if (!file.exists()) {
throw new FileNotFoundException("音频文件不存在: " + filePath);
}
FileInputStream fis = new FileInputStream(file);
try {
byte[] fileData = new byte[(int) file.length()];
fis.read(fileData);
if (filePath.toLowerCase().endsWith(".wav") && fileData.length > Config.WAV_HEADER_SIZE) {
byte[] audioData = new byte[fileData.length - Config.WAV_HEADER_SIZE];
System.arraycopy(fileData, Config.WAV_HEADER_SIZE, audioData, 0, audioData.length);
return audioData;
}
return fileData;
} finally {
fis.close();
}
}
public static byte[] int16SamplesToBytes(short[] samples) {
byte[] bytes = new byte[samples.length * 2];
ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer().put(samples);
return bytes;
}
}

View File

@@ -0,0 +1,125 @@
package com.bigwo.realtimerag;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;
public class BackendApi {
private static final ObjectMapper objectMapper = new ObjectMapper();
public static class QueryResult {
public String sessionId;
public String query;
public String contentText;
public String ragJson;
}
public static void createSession(String sessionId, String userId) throws IOException {
ObjectNode root = objectMapper.createObjectNode();
root.put("sessionId", sessionId);
if (userId == null) {
root.putNull("userId");
} else {
root.put("userId", userId);
}
JsonNode response = postJson(Config.BACKEND_BASE_URL + "/session", root);
ensureSuccess(response, "创建直连会话失败");
}
public static QueryResult queryKnowledge(String sessionId, String query, boolean appendUserMessage) throws IOException {
ObjectNode root = objectMapper.createObjectNode();
root.put("sessionId", sessionId);
root.put("query", query == null ? "" : query);
root.put("appendUserMessage", appendUserMessage);
JsonNode response = postJson(Config.BACKEND_BASE_URL + "/query", root);
ensureSuccess(response, "查询知识库失败");
JsonNode data = response.path("data");
QueryResult result = new QueryResult();
result.sessionId = data.path("sessionId").asText(sessionId);
result.query = data.path("query").asText(query == null ? "" : query);
result.contentText = data.path("contentText").asText("");
result.ragJson = data.path("ragJson").asText("[]");
return result;
}
public static void addMessage(String sessionId, String role, String text, String source, String toolName) throws IOException {
ObjectNode root = objectMapper.createObjectNode();
root.put("sessionId", sessionId);
root.put("role", role);
root.put("text", text == null ? "" : text);
root.put("source", source);
if (toolName == null) {
root.putNull("toolName");
} else {
root.put("toolName", toolName);
}
JsonNode response = postJson(Config.BACKEND_BASE_URL + "/message", root);
ensureSuccess(response, "写入消息失败");
}
public static void stopSession(String sessionId) throws IOException {
ObjectNode root = objectMapper.createObjectNode();
root.put("sessionId", sessionId);
JsonNode response = postJson(Config.BACKEND_BASE_URL + "/stop", root);
ensureSuccess(response, "停止直连会话失败");
}
private static JsonNode postJson(String url, ObjectNode body) throws IOException {
HttpURLConnection connection = (HttpURLConnection) new URL(url).openConnection();
connection.setRequestMethod("POST");
connection.setConnectTimeout(15000);
connection.setReadTimeout(60000);
connection.setDoOutput(true);
connection.setRequestProperty("Content-Type", "application/json; charset=UTF-8");
byte[] requestBytes = objectMapper.writeValueAsBytes(body);
OutputStream outputStream = connection.getOutputStream();
try {
outputStream.write(requestBytes);
outputStream.flush();
} finally {
outputStream.close();
}
int status = connection.getResponseCode();
InputStream inputStream = status >= 200 && status < 300 ? connection.getInputStream() : connection.getErrorStream();
String responseText = readAll(inputStream);
if (responseText == null || responseText.trim().isEmpty()) {
throw new IOException("后端返回空响应HTTP " + status);
}
return objectMapper.readTree(responseText);
}
private static void ensureSuccess(JsonNode response, String prefix) throws IOException {
if (!response.path("success").asBoolean(false)) {
String message = response.path("error").asText(response.toString());
throw new IOException(prefix + ": " + message);
}
}
private static String readAll(InputStream inputStream) throws IOException {
if (inputStream == null) {
return "";
}
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
try {
byte[] buffer = new byte[4096];
int read;
while ((read = inputStream.read(buffer)) != -1) {
outputStream.write(buffer, 0, read);
}
return new String(outputStream.toByteArray(), StandardCharsets.UTF_8);
} finally {
inputStream.close();
outputStream.close();
}
}
}

View File

@@ -0,0 +1,266 @@
package com.bigwo.realtimerag;
import java.io.IOException;
import java.net.URI;
import java.util.HashMap;
import java.util.Map;
import java.util.Scanner;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
public class CallManager {
private final String sessionId;
private NetClient netClient;
private AudioCapture audioCapture;
private Thread audioSendThread;
private Thread textInputThread;
private final AtomicBoolean isRunning = new AtomicBoolean(false);
private static volatile CallManager currentInstance;
public CallManager() {
this.sessionId = Protocol.generateSessionId();
}
public void start() throws Exception {
currentInstance = this;
ServerResponseHandler.resetForNewSession();
BackendApi.createSession(sessionId, "direct_" + sessionId.substring(0, 8));
System.out.println("已在 test2/server 创建直连会话: " + sessionId);
connectWebSocket();
isRunning.set(true);
if ("text".equals(Config.mod)) {
startTextMode();
} else {
startAudioMode();
}
waitForCompletion();
}
private void connectWebSocket() throws Exception {
URI uri = new URI(Config.WS_URL);
Map<String, String> headers = new HashMap<String, String>();
headers.put("X-Api-Resource-Id", Config.API_RESOURCE_ID);
headers.put("X-Api-Access-Key", Config.API_ACCESS_KEY);
headers.put("X-Api-App-Key", Config.API_APP_KEY);
headers.put("X-Api-App-ID", Config.API_APP_ID);
headers.put("X-Api-Connect-Id", sessionId);
netClient = new NetClient(uri, headers);
netClient.connectBlocking(30, TimeUnit.SECONDS);
if (!netClient.isConnected()) {
throw new IOException("WebSocket连接失败");
}
startConnection();
startSession();
}
private void startConnection() throws Exception {
netClient.sendProtocolMessage(sessionId, "{}", 1);
}
private void startSession() throws Exception {
RequestPayloads.StartSessionPayload payload = new RequestPayloads.StartSessionPayload();
payload.dialog.extra = createExtraMap();
String jsonPayload = new com.fasterxml.jackson.databind.ObjectMapper().writeValueAsString(payload);
netClient.sendProtocolMessage(sessionId, jsonPayload, 100);
waitForSessionStarted();
}
private Map<String, Object> createExtraMap() {
Map<String, Object> extra = new HashMap<String, Object>();
extra.put("strict_audit", false);
extra.put("audit_response", "抱歉,这个问题我暂时无法回答。");
if ("text".equals(Config.mod)) {
extra.put("input_mod", "text");
} else if (Config.audioFilePath != null && !Config.audioFilePath.isEmpty()) {
extra.put("input_mod", "audio_file");
} else {
extra.put("input_mod", "audio");
}
extra.put("model", "O");
return extra;
}
private void waitForSessionStarted() throws Exception {
long startTime = System.currentTimeMillis();
while (System.currentTimeMillis() - startTime < 30000) {
Protocol.Message message = netClient.pollIncomingMessage(1, TimeUnit.SECONDS);
if (message != null && message.type == Protocol.MsgType.FULL_SERVER && message.event == 150) {
return;
}
}
throw new IOException("会话启动超时");
}
private void startAudioMode() throws Exception {
sendGreetingMessage();
if (Config.audioFilePath != null && !Config.audioFilePath.isEmpty()) {
startFilePlayback();
} else {
startMicrophoneCapture();
}
startMessageReceiver();
}
private void startTextMode() throws Exception {
sendGreetingMessage();
startTextInput();
startMessageReceiver();
}
private void sendGreetingMessage() throws Exception {
RequestPayloads.SayHelloPayload payload = new RequestPayloads.SayHelloPayload("你好,我是外接知识库测试助手,请开始提问。");
String jsonPayload = new com.fasterxml.jackson.databind.ObjectMapper().writeValueAsString(payload);
netClient.sendProtocolMessage(sessionId, jsonPayload, 300);
}
private void startMicrophoneCapture() throws Exception {
audioCapture = new AudioCapture();
audioCapture.startCapture();
audioSendThread = new Thread(new Runnable() {
@Override
public void run() {
microphoneSendLoop();
}
}, "MicrophoneAudioSend");
audioSendThread.start();
}
private void microphoneSendLoop() {
try {
while (isRunning.get() && audioCapture.isCapturing()) {
byte[] audioData = audioCapture.readAudioData();
if (audioData != null) {
netClient.sendAudioData(sessionId, audioData);
}
Thread.sleep(Config.AUDIO_SEND_INTERVAL);
}
} catch (Exception e) {
System.err.println("麦克风发送线程错误: " + e.getMessage());
}
}
private void startFilePlayback() throws Exception {
final byte[] audioData = AudioCapture.readWavFile(Config.audioFilePath);
audioSendThread = new Thread(new Runnable() {
@Override
public void run() {
fileSendLoop(audioData);
}
}, "FileAudioSend");
audioSendThread.start();
}
private void fileSendLoop(byte[] audioData) {
try {
int position = 0;
while (isRunning.get() && position < audioData.length) {
int currentChunkSize = Math.min(Config.AUDIO_CHUNK_SIZE, audioData.length - position);
byte[] chunk = new byte[currentChunkSize];
System.arraycopy(audioData, position, chunk, 0, currentChunkSize);
netClient.sendAudioData(sessionId, chunk);
position += currentChunkSize;
Thread.sleep(Config.AUDIO_SEND_INTERVAL);
}
} catch (Exception e) {
System.err.println("文件发送线程错误: " + e.getMessage());
}
}
private void startTextInput() {
textInputThread = new Thread(new Runnable() {
@Override
public void run() {
textInputLoop();
}
}, "TextInput");
textInputThread.start();
}
private void textInputLoop() {
Scanner scanner = new Scanner(System.in);
System.out.println("请输入文本(输入 quit 退出):");
try {
while (isRunning.get()) {
String text = scanner.nextLine();
if ("quit".equalsIgnoreCase(text)) {
stop();
break;
}
if (text != null && !text.trim().isEmpty()) {
ServerResponseHandler.setLatestUserText(text);
netClient.sendChatTextQuery(sessionId, text);
}
}
} catch (Exception e) {
System.err.println("文本输入线程错误: " + e.getMessage());
}
}
private void startMessageReceiver() {
Thread receiverThread = new Thread(new Runnable() {
@Override
public void run() {
messageReceiveLoop();
}
}, "MessageReceiver");
receiverThread.start();
}
private void messageReceiveLoop() {
try {
while (isRunning.get()) {
Protocol.Message message = netClient.pollIncomingMessage(1, TimeUnit.SECONDS);
if (message != null && message.type == Protocol.MsgType.ERROR) {
System.err.println("服务器错误: " + (message.payload == null ? "" : new String(message.payload)));
}
}
} catch (Exception e) {
System.err.println("消息接收线程错误: " + e.getMessage());
}
}
private void waitForCompletion() throws InterruptedException {
while (isRunning.get()) {
Thread.sleep(100);
}
}
public void stop() {
isRunning.set(false);
try {
if (netClient != null && netClient.isConnected()) {
netClient.sendProtocolMessage(sessionId, "{}", 102);
Thread.sleep(100);
}
} catch (Exception e) {
System.err.println("发送结束消息失败: " + e.getMessage());
}
try {
BackendApi.stopSession(sessionId);
} catch (Exception e) {
System.err.println("回收 test2 会话失败: " + e.getMessage());
}
if (audioCapture != null) {
audioCapture.stopCapture();
}
if (netClient != null) {
netClient.close();
}
try {
if (audioSendThread != null) {
audioSendThread.join(1000);
}
if (textInputThread != null) {
textInputThread.join(1000);
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
public static void stopFromHandler() {
if (currentInstance != null) {
currentInstance.stop();
}
}
}

View File

@@ -0,0 +1,248 @@
package com.bigwo.realtimerag;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class Config {
public static final String WS_URL = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue";
public static final String API_RESOURCE_ID = "volc.speech.dialog";
public static final String DEFAULT_TEST2_ENV_RELATIVE_PATH = "..\\..\\test2\\server\\.env";
public static final String DEFAULT_BACKEND_BASE_URL = "https://demo.tensorgrove.com.cn/api/voice/direct";
public static String API_APP_ID = "";
public static String API_ACCESS_KEY = "";
public static String API_APP_KEY = "PlgvMymc7f3tQnJ6";
public static String BACKEND_BASE_URL = DEFAULT_BACKEND_BASE_URL;
public static final int INPUT_SAMPLE_RATE = 16000;
public static final int OUTPUT_SAMPLE_RATE = 24000;
public static final int CHANNELS = 1;
public static final int WAV_HEADER_SIZE = 44;
public static final int AUDIO_CHUNK_SIZE = 640;
public static final long AUDIO_SEND_INTERVAL = 20L;
public static final String DEFAULT_PCM = "pcm";
public static final String PCM_S16LE = "pcm_s16le";
public static final String DEFAULT_SPEAKER = "zh_female_vv_jupiter_bigtts";
public static String audioFilePath = "";
public static String mod = "audio";
public static String pcmFormat = PCM_S16LE;
public static String ragFilePath = "sample_rag.json";
public static long ragDelayMs = 3000L;
public static String loadedTest2EnvPath = "";
public static boolean volcEnabled = false;
public static String volcAccessKey = "";
public static String volcSecretKey = "";
public static String volcApiKey = "";
public static String volcEndpointId = "";
public static String volcKnowledgeBaseIds = "";
public static int volcTopK = 3;
public static double volcThreshold = 0.5;
public static void setAppId(String appId) {
API_APP_ID = appId;
}
public static void setAccessKey(String accessKey) {
API_ACCESS_KEY = accessKey;
}
public static void setAudioFilePath(String path) {
audioFilePath = path == null ? "" : path;
}
public static void setMod(String mode) {
mod = mode == null ? "audio" : mode;
}
public static void setPcmFormat(String format) {
pcmFormat = format == null ? PCM_S16LE : format;
}
public static void setRagFilePath(String path) {
ragFilePath = path == null ? "" : path;
}
public static void setRagDelayMs(long delayMs) {
ragDelayMs = delayMs;
}
public static void setBackendBaseUrl(String backendBaseUrl) {
BACKEND_BASE_URL = backendBaseUrl == null || backendBaseUrl.trim().isEmpty()
? DEFAULT_BACKEND_BASE_URL
: backendBaseUrl.trim();
}
public static void setVolcEnabled(boolean enabled) {
volcEnabled = enabled;
}
public static void setVolcAccessKey(String accessKey) {
volcAccessKey = accessKey == null ? "" : accessKey;
}
public static void setVolcSecretKey(String secretKey) {
volcSecretKey = secretKey == null ? "" : secretKey;
}
public static void setVolcApiKey(String apiKey) {
volcApiKey = apiKey == null ? "" : apiKey;
}
public static void setVolcEndpointId(String endpointId) {
volcEndpointId = endpointId == null ? "" : endpointId;
}
public static void setVolcKnowledgeBaseIds(String knowledgeBaseIds) {
volcKnowledgeBaseIds = knowledgeBaseIds == null ? "" : knowledgeBaseIds;
}
public static void setVolcTopK(int topK) {
volcTopK = topK;
}
public static void setVolcThreshold(double threshold) {
volcThreshold = threshold;
}
public static void loadFromTest2Env(String explicitPath) {
Path envPath = resolveTest2EnvPath(explicitPath);
if (envPath == null) {
System.out.println("未找到 test2/server/.env将仅使用命令行参数");
return;
}
try {
Map<String, String> env = parseDotEnv(envPath);
if ((API_APP_ID == null || API_APP_ID.trim().isEmpty()) && env.containsKey("VOLC_S2S_APP_ID")) {
API_APP_ID = env.get("VOLC_S2S_APP_ID");
}
if ((API_ACCESS_KEY == null || API_ACCESS_KEY.trim().isEmpty()) && env.containsKey("VOLC_S2S_TOKEN")) {
API_ACCESS_KEY = env.get("VOLC_S2S_TOKEN");
}
if (env.containsKey("VOLC_ARK_ENABLED")) {
volcEnabled = Boolean.parseBoolean(env.get("VOLC_ARK_ENABLED"));
} else if (env.containsKey("VOLC_KNOWLEDGE_ENABLED")) {
volcEnabled = Boolean.parseBoolean(env.get("VOLC_KNOWLEDGE_ENABLED"));
}
if ((volcAccessKey == null || volcAccessKey.trim().isEmpty()) && env.containsKey("VOLC_ACCESS_KEY_ID")) {
volcAccessKey = env.get("VOLC_ACCESS_KEY_ID");
}
if ((volcSecretKey == null || volcSecretKey.trim().isEmpty()) && env.containsKey("VOLC_SECRET_ACCESS_KEY")) {
volcSecretKey = env.get("VOLC_SECRET_ACCESS_KEY");
}
if ((volcApiKey == null || volcApiKey.trim().isEmpty()) && env.containsKey("VOLC_ARK_API_KEY")) {
volcApiKey = env.get("VOLC_ARK_API_KEY");
}
if ((volcEndpointId == null || volcEndpointId.trim().isEmpty()) && env.containsKey("VOLC_ARK_ENDPOINT_ID")) {
volcEndpointId = env.get("VOLC_ARK_ENDPOINT_ID");
}
if ((volcKnowledgeBaseIds == null || volcKnowledgeBaseIds.trim().isEmpty()) && env.containsKey("VOLC_ARK_KNOWLEDGE_BASE_IDS")) {
volcKnowledgeBaseIds = env.get("VOLC_ARK_KNOWLEDGE_BASE_IDS");
}
if (env.containsKey("VOLC_ARK_KNOWLEDGE_TOP_K")) {
try {
volcTopK = Integer.parseInt(env.get("VOLC_ARK_KNOWLEDGE_TOP_K"));
} catch (NumberFormatException e) {
System.err.println("VOLC_ARK_KNOWLEDGE_TOP_K 解析失败,使用默认值: " + volcTopK);
}
} else if (env.containsKey("VOLC_KNOWLEDGE_TOP_K")) {
try {
volcTopK = Integer.parseInt(env.get("VOLC_KNOWLEDGE_TOP_K"));
} catch (NumberFormatException e) {
System.err.println("VOLC_KNOWLEDGE_TOP_K 解析失败,使用默认值: " + volcTopK);
}
}
if (env.containsKey("VOLC_ARK_KNOWLEDGE_THRESHOLD")) {
try {
volcThreshold = Double.parseDouble(env.get("VOLC_ARK_KNOWLEDGE_THRESHOLD"));
} catch (NumberFormatException e) {
System.err.println("VOLC_ARK_KNOWLEDGE_THRESHOLD 解析失败,使用默认值: " + volcThreshold);
}
} else if (env.containsKey("VOLC_KNOWLEDGE_THRESHOLD")) {
try {
volcThreshold = Double.parseDouble(env.get("VOLC_KNOWLEDGE_THRESHOLD"));
} catch (NumberFormatException e) {
System.err.println("VOLC_KNOWLEDGE_THRESHOLD 解析失败,使用默认值: " + volcThreshold);
}
}
loadedTest2EnvPath = envPath.toAbsolutePath().normalize().toString();
System.out.println("已加载 test2 配置: " + loadedTest2EnvPath);
if (volcEnabled) {
System.out.println("火山方舟知识库已启用数据集ID: " + volcKnowledgeBaseIds);
}
} catch (IOException e) {
System.err.println("加载 test2 配置失败: " + e.getMessage());
}
}
public static String readExternalRagJson() throws IOException {
if (ragFilePath == null || ragFilePath.trim().isEmpty()) {
return "[]";
}
byte[] bytes = Files.readAllBytes(Paths.get(ragFilePath));
return new String(bytes, StandardCharsets.UTF_8);
}
private static Path resolveTest2EnvPath(String explicitPath) {
if (explicitPath != null && !explicitPath.trim().isEmpty()) {
Path path = Paths.get(explicitPath);
if (Files.exists(path)) {
return path;
}
System.err.println("指定的 test2 环境文件不存在: " + explicitPath);
return null;
}
List<String> candidates = Arrays.asList(
DEFAULT_TEST2_ENV_RELATIVE_PATH,
"..\\test2\\server\\.env",
"c:\\Users\\UI\\Desktop\\bigwo\\test2\\server\\.env"
);
for (String candidate : candidates) {
Path path = Paths.get(candidate);
if (Files.exists(path)) {
return path;
}
}
return null;
}
private static Map<String, String> parseDotEnv(Path envPath) throws IOException {
Map<String, String> values = new HashMap<String, String>();
List<String> lines = Files.readAllLines(envPath, StandardCharsets.UTF_8);
for (String rawLine : lines) {
String line = rawLine == null ? "" : rawLine.trim();
if (line.isEmpty() || line.startsWith("#")) {
continue;
}
int index = line.indexOf('=');
if (index <= 0) {
continue;
}
String key = line.substring(0, index).trim();
String value = line.substring(index + 1).trim();
values.put(key, stripQuotes(value));
}
return values;
}
private static String stripQuotes(String value) {
if (value == null || value.length() < 2) {
return value;
}
if ((value.startsWith("\"") && value.endsWith("\"")) || (value.startsWith("'") && value.endsWith("'"))) {
return value.substring(1, value.length() - 1);
}
return value;
}
}

View File

@@ -0,0 +1,153 @@
package com.bigwo.realtimerag;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
public class Main {
public static void main(String[] args) {
CommandLine cmd = parseCommandLine(args);
if (cmd == null) {
System.exit(1);
}
Config.loadFromTest2Env(cmd.getOptionValue("test2-env"));
applyConfiguration(cmd);
if (!validateConfiguration()) {
System.exit(1);
}
CallManager callManager = new CallManager();
try {
Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {
@Override
public void run() {
callManager.stop();
}
}));
callManager.start();
} catch (Exception e) {
System.err.println("运行错误: " + e.getMessage());
e.printStackTrace();
System.exit(1);
}
}
private static CommandLine parseCommandLine(String[] args) {
Options options = new Options();
options.addOption(Option.builder("a").longOpt("audio").hasArg().argName("FILE").desc("音频文件路径").build());
options.addOption(Option.builder("m").longOpt("mod").hasArg().argName("MODE").desc("输入模式audio 或 text").build());
options.addOption(Option.builder("f").longOpt("format").hasArg().argName("FORMAT").desc("音频格式pcm 或 pcm_s16le").build());
options.addOption(Option.builder().longOpt("app_id").hasArg().argName("APP_ID").desc("应用ID").build());
options.addOption(Option.builder().longOpt("access_key").hasArg().argName("ACCESS_KEY").desc("访问密钥").build());
options.addOption(Option.builder().longOpt("rag-file").hasArg().argName("RAG_FILE").desc("external_rag JSON 文件路径").build());
options.addOption(Option.builder().longOpt("rag-delay-ms").hasArg().argName("MILLISECONDS").desc("发送 external_rag 的延迟").build());
options.addOption(Option.builder().longOpt("test2-env").hasArg().argName("ENV_FILE").desc("test2/server/.env 路径,默认自动查找").build());
options.addOption(Option.builder().longOpt("backend-url").hasArg().argName("URL").desc("test2/server 直连接口地址").build());
options.addOption(Option.builder().longOpt("volc-enabled").desc("启用火山方舟知识库").build());
options.addOption(Option.builder().longOpt("volc-ak").hasArg().argName("AK").desc("火山云 Access Key ID").build());
options.addOption(Option.builder().longOpt("volc-sk").hasArg().argName("SK").desc("火山云 Secret Access Key").build());
options.addOption(Option.builder().longOpt("volc-api-key").hasArg().argName("API_KEY").desc("火山方舟 API Key").build());
options.addOption(Option.builder().longOpt("volc-endpoint").hasArg().argName("ENDPOINT").desc("火山方舟 Endpoint ID").build());
options.addOption(Option.builder().longOpt("volc-kb-ids").hasArg().argName("KB_IDS").desc("火山方舟知识库数据集ID多个用逗号分隔").build());
options.addOption(Option.builder().longOpt("volc-topk").hasArg().argName("N").desc("检索返回数量默认3").build());
options.addOption(Option.builder().longOpt("volc-threshold").hasArg().argName("THRESHOLD").desc("相似度阈值默认0.5").build());
options.addOption(Option.builder("h").longOpt("help").desc("显示帮助信息").build());
CommandLineParser parser = new DefaultParser();
HelpFormatter formatter = new HelpFormatter();
try {
CommandLine cmd = parser.parse(options, args);
if (cmd.hasOption("help")) {
formatter.printHelp("mvn exec:java -Dexec.args=\"--mod=text --rag-file=sample_rag.json\"", options);
return null;
}
return cmd;
} catch (ParseException e) {
System.err.println("参数解析错误: " + e.getMessage());
formatter.printHelp("mvn exec:java", options);
return null;
}
}
private static void applyConfiguration(CommandLine cmd) {
if (cmd.hasOption("audio")) {
Config.setAudioFilePath(cmd.getOptionValue("audio"));
}
if (cmd.hasOption("mod")) {
Config.setMod(cmd.getOptionValue("mod"));
}
if (cmd.hasOption("format")) {
Config.setPcmFormat(cmd.getOptionValue("format"));
}
if (cmd.hasOption("app_id")) {
Config.setAppId(cmd.getOptionValue("app_id"));
}
if (cmd.hasOption("access_key")) {
Config.setAccessKey(cmd.getOptionValue("access_key"));
}
if (cmd.hasOption("rag-file")) {
Config.setRagFilePath(cmd.getOptionValue("rag-file"));
}
if (cmd.hasOption("rag-delay-ms")) {
Config.setRagDelayMs(Long.parseLong(cmd.getOptionValue("rag-delay-ms")));
}
if (cmd.hasOption("backend-url")) {
Config.setBackendBaseUrl(cmd.getOptionValue("backend-url"));
}
if (cmd.hasOption("volc-enabled")) {
Config.setVolcEnabled(true);
}
if (cmd.hasOption("volc-ak")) {
Config.setVolcAccessKey(cmd.getOptionValue("volc-ak"));
}
if (cmd.hasOption("volc-sk")) {
Config.setVolcSecretKey(cmd.getOptionValue("volc-sk"));
}
if (cmd.hasOption("volc-api-key")) {
Config.setVolcApiKey(cmd.getOptionValue("volc-api-key"));
}
if (cmd.hasOption("volc-endpoint")) {
Config.setVolcEndpointId(cmd.getOptionValue("volc-endpoint"));
}
if (cmd.hasOption("volc-kb-ids")) {
Config.setVolcKnowledgeBaseIds(cmd.getOptionValue("volc-kb-ids"));
}
if (cmd.hasOption("volc-topk")) {
Config.setVolcTopK(Integer.parseInt(cmd.getOptionValue("volc-topk")));
}
if (cmd.hasOption("volc-threshold")) {
Config.setVolcThreshold(Double.parseDouble(cmd.getOptionValue("volc-threshold")));
}
}
private static boolean validateConfiguration() {
if (Config.API_APP_ID == null || Config.API_APP_ID.trim().isEmpty()) {
System.err.println("错误:必须设置 app_id");
return false;
}
if (Config.API_ACCESS_KEY == null || Config.API_ACCESS_KEY.trim().isEmpty()) {
System.err.println("错误:必须设置 access_key");
return false;
}
if (Config.loadedTest2EnvPath != null && !Config.loadedTest2EnvPath.isEmpty()) {
System.out.println("当前使用 test2 配置文件: " + Config.loadedTest2EnvPath);
}
System.out.println("当前使用 test2 直连接口: " + Config.BACKEND_BASE_URL);
if (Config.audioFilePath != null && !Config.audioFilePath.isEmpty()) {
java.io.File file = new java.io.File(Config.audioFilePath);
if (!file.exists()) {
System.err.println("错误:音频文件不存在: " + Config.audioFilePath);
return false;
}
}
java.io.File ragFile = new java.io.File(Config.ragFilePath);
if (!ragFile.exists()) {
System.err.println("错误rag 文件不存在: " + Config.ragFilePath);
return false;
}
return true;
}
}

View File

@@ -0,0 +1,224 @@
package com.bigwo.realtimerag;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.drafts.Draft_6455;
import org.java_websocket.handshake.ServerHandshake;
import javax.sound.sampled.AudioFormat;
import javax.sound.sampled.AudioSystem;
import javax.sound.sampled.DataLine;
import javax.sound.sampled.LineUnavailableException;
import javax.sound.sampled.SourceDataLine;
import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
public class NetClient extends WebSocketClient {
private final BlockingQueue<Protocol.Message> incomingMessages = new LinkedBlockingQueue<Protocol.Message>();
private final BlockingQueue<byte[]> audioQueue = new LinkedBlockingQueue<byte[]>();
private volatile boolean isConnected = false;
private volatile boolean shouldStop = false;
private SourceDataLine audioOutputLine;
private Thread audioPlaybackThread;
private volatile String logid;
public NetClient(URI serverUri, Map<String, String> headers) {
super(serverUri, new Draft_6455(), headers, 0);
if (!isAudioFileInput()) {
initializeAudioOutput();
}
}
private void initializeAudioOutput() {
try {
AudioFormat format = new AudioFormat(
Config.OUTPUT_SAMPLE_RATE,
16,
Config.CHANNELS,
true,
false
);
DataLine.Info info = new DataLine.Info(SourceDataLine.class, format);
if (!AudioSystem.isLineSupported(info)) {
System.err.println("不支持音频输出格式");
return;
}
audioOutputLine = (SourceDataLine) AudioSystem.getLine(info);
audioOutputLine.open(format);
audioOutputLine.start();
audioPlaybackThread = new Thread(new Runnable() {
@Override
public void run() {
audioPlaybackLoop();
}
}, "AudioPlayback");
audioPlaybackThread.start();
} catch (LineUnavailableException e) {
System.err.println("音频输出初始化失败: " + e.getMessage());
}
}
private void audioPlaybackLoop() {
while (!shouldStop) {
try {
byte[] audioData = audioQueue.poll(50, TimeUnit.MILLISECONDS);
if (audioData != null && audioOutputLine != null) {
audioOutputLine.write(audioData, 0, audioData.length);
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
break;
} catch (Exception e) {
System.err.println("音频播放错误: " + e.getMessage());
}
}
}
@Override
public void onOpen(ServerHandshake handshake) {
isConnected = true;
logid = handshake.getFieldValue("X-Tt-Logid");
System.out.println("WebSocket连接已建立logid=" + (logid == null ? "" : logid));
}
@Override
public void onMessage(String message) {
System.out.println("收到文本消息: " + message);
}
@Override
public void onMessage(ByteBuffer bytes) {
try {
byte[] data = new byte[bytes.remaining()];
bytes.get(data);
Protocol.Message message = Protocol.unmarshal(data);
switch (message.type) {
case FULL_SERVER:
ServerResponseHandler.handleFullServerMessage(this, message);
break;
case AUDIO_ONLY_SERVER:
ServerResponseHandler.handleAudioOnlyServerMessage(this, message);
break;
case ERROR:
ServerResponseHandler.handleErrorMessage(message);
break;
default:
break;
}
incomingMessages.offer(message);
} catch (Exception e) {
System.err.println("处理二进制消息失败: " + e.getMessage());
}
}
public void playAudioData(byte[] audioData) {
if (isAudioFileInput()) {
return;
}
if (audioData == null || audioData.length == 0) {
return;
}
try {
if (Config.PCM_S16LE.equals(Config.pcmFormat)) {
audioQueue.offer(audioData);
return;
}
if (audioData.length % 4 != 0) {
return;
}
int sampleCount = audioData.length / 4;
short[] samples = new short[sampleCount];
ByteBuffer buffer = ByteBuffer.wrap(audioData).order(ByteOrder.LITTLE_ENDIAN);
for (int i = 0; i < sampleCount; i++) {
float sample = buffer.getFloat();
samples[i] = (short) Math.max(-32768, Math.min(32767, sample * 32767.0f));
}
audioQueue.offer(AudioCapture.int16SamplesToBytes(samples));
} catch (Exception e) {
System.err.println("播放音频数据失败: " + e.getMessage());
}
}
@Override
public void onClose(int code, String reason, boolean remote) {
isConnected = false;
cleanup();
System.out.println("WebSocket关闭code=" + code + ", reason=" + reason);
}
@Override
public void onError(Exception ex) {
System.err.println("WebSocket错误: " + ex.getMessage());
}
public boolean isConnected() {
return isConnected;
}
public String getLogid() {
return logid;
}
public Protocol.Message pollIncomingMessage(long timeout, TimeUnit unit) throws InterruptedException {
return incomingMessages.poll(timeout, unit);
}
public void sendAudioData(String sessionId, byte[] audioData) throws IOException {
if (!isConnected) {
throw new IOException("WebSocket未连接");
}
send(Protocol.createAudioMessage(sessionId, audioData));
}
public void sendProtocolMessage(String sessionId, String payload, int eventId) throws IOException {
if (!isConnected) {
throw new IOException("WebSocket未连接");
}
try {
Protocol.Message message = new Protocol.Message();
message.type = Protocol.MsgType.FULL_CLIENT;
message.typeFlag = Protocol.MSG_TYPE_FLAG_WITH_EVENT;
message.event = eventId;
message.sessionId = sessionId;
message.payload = payload.getBytes("UTF-8");
send(Protocol.marshal(message));
} catch (Exception e) {
throw new IOException("发送协议消息失败: " + e.getMessage(), e);
}
}
public void sendChatTextQuery(String sessionId, String text) throws IOException {
if (!isConnected) {
throw new IOException("WebSocket未连接");
}
try {
RequestPayloads.ChatTextQueryPayload payload = new RequestPayloads.ChatTextQueryPayload(text);
String jsonPayload = new com.fasterxml.jackson.databind.ObjectMapper().writeValueAsString(payload);
sendProtocolMessage(sessionId, jsonPayload, 501);
} catch (Exception e) {
throw new IOException("发送ChatTextQuery失败: " + e.getMessage(), e);
}
}
private void cleanup() {
shouldStop = true;
if (audioPlaybackThread != null) {
audioPlaybackThread.interrupt();
}
if (audioOutputLine != null) {
audioOutputLine.drain();
audioOutputLine.stop();
audioOutputLine.close();
}
ServerResponseHandler.saveAudioToPCMFile("output.pcm");
}
private boolean isAudioFileInput() {
return Config.audioFilePath != null && !Config.audioFilePath.isEmpty();
}
}

View File

@@ -0,0 +1,217 @@
package com.bigwo.realtimerag;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
public class Protocol {
private static final ObjectMapper objectMapper = new ObjectMapper();
public enum MsgType {
INVALID(0),
FULL_CLIENT(1),
AUDIO_ONLY_CLIENT(2),
FULL_SERVER(9),
AUDIO_ONLY_SERVER(11),
ERROR(15);
private final int value;
MsgType(int value) {
this.value = value;
}
public int getValue() {
return value;
}
public static MsgType fromBits(int bits) {
for (MsgType type : values()) {
if (type.value == bits) {
return type;
}
}
return INVALID;
}
}
public static final int MSG_TYPE_FLAG_WITH_EVENT = 0b100;
public static final int VERSION_1 = 0x10;
public static final int HEADER_SIZE_4 = 0x1;
public static final int SERIALIZATION_RAW = 0;
public static final int SERIALIZATION_JSON = 0b1 << 4;
public static final int COMPRESSION_NONE = 0;
public static class Message {
public MsgType type = MsgType.INVALID;
public int typeFlag;
public int event;
public String sessionId;
public String connectId;
public byte[] payload;
public long errorCode;
}
public static byte[] marshal(Message msg) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(baos);
dos.writeByte(VERSION_1 | HEADER_SIZE_4);
dos.writeByte((msg.type.getValue() << 4) | (msg.typeFlag & 0x0F));
dos.writeByte(SERIALIZATION_JSON | COMPRESSION_NONE);
dos.writeByte(0);
if (containsEvent(msg.typeFlag)) {
dos.writeInt(msg.event);
}
if (shouldWriteSessionId(msg)) {
byte[] sessionIdBytes = msg.sessionId.getBytes("UTF-8");
dos.writeInt(sessionIdBytes.length);
dos.write(sessionIdBytes);
}
if (shouldWriteConnectId(msg)) {
byte[] connectIdBytes = msg.connectId.getBytes("UTF-8");
dos.writeInt(connectIdBytes.length);
dos.write(connectIdBytes);
}
if (msg.type == MsgType.ERROR) {
dos.writeInt((int) msg.errorCode);
}
if (msg.payload != null) {
dos.writeInt(msg.payload.length);
dos.write(msg.payload);
} else {
dos.writeInt(0);
}
return baos.toByteArray();
}
public static Message unmarshal(byte[] data) throws IOException {
if (data.length < 4) {
throw new IOException("数据长度不足");
}
ByteBuffer buf = ByteBuffer.wrap(data).order(ByteOrder.BIG_ENDIAN);
Message msg = new Message();
buf.get();
int typeAndFlag = buf.get() & 0xFF;
buf.get();
buf.get();
int msgTypeBits = (typeAndFlag >> 4) & 0x0F;
msg.type = MsgType.fromBits(msgTypeBits);
msg.typeFlag = typeAndFlag & 0x0F;
if (containsEvent(msg.typeFlag)) {
msg.event = buf.getInt();
}
if (shouldReadSessionId(msg)) {
int size = buf.getInt();
if (size > 0) {
byte[] bytes = new byte[size];
buf.get(bytes);
msg.sessionId = new String(bytes, "UTF-8");
}
}
if (shouldReadConnectId(msg)) {
int size = buf.getInt();
if (size > 0) {
byte[] bytes = new byte[size];
buf.get(bytes);
msg.connectId = new String(bytes, "UTF-8");
}
}
if (msg.type == MsgType.ERROR) {
msg.errorCode = buf.getInt() & 0xFFFFFFFFL;
}
if (buf.remaining() >= 4) {
int payloadSize = buf.getInt();
if (payloadSize > 0 && buf.remaining() >= payloadSize) {
msg.payload = new byte[payloadSize];
buf.get(msg.payload);
}
}
return msg;
}
private static boolean containsEvent(int typeFlag) {
return (typeFlag & MSG_TYPE_FLAG_WITH_EVENT) == MSG_TYPE_FLAG_WITH_EVENT;
}
private static boolean shouldWriteSessionId(Message msg) {
return containsEvent(msg.typeFlag) && msg.event != 1 && msg.event != 50 && msg.event != 51 && msg.event != 52;
}
private static boolean shouldReadSessionId(Message msg) {
return containsEvent(msg.typeFlag) && msg.event != 1 && msg.event != 50 && msg.event != 51 && msg.event != 52;
}
private static boolean shouldWriteConnectId(Message msg) {
return containsEvent(msg.typeFlag) && (msg.event == 50 || msg.event == 51 || msg.event == 52);
}
private static boolean shouldReadConnectId(Message msg) {
return containsEvent(msg.typeFlag) && (msg.event == 50 || msg.event == 51 || msg.event == 52);
}
public static byte[] createAudioMessage(String sessionId, byte[] audioData) throws IOException {
Message msg = new Message();
msg.type = MsgType.AUDIO_ONLY_CLIENT;
msg.typeFlag = MSG_TYPE_FLAG_WITH_EVENT;
msg.event = 200;
msg.sessionId = sessionId;
msg.payload = audioData;
return marshalRawAudio(msg);
}
private static byte[] marshalRawAudio(Message message) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(baos);
dos.writeByte(VERSION_1 | HEADER_SIZE_4);
dos.writeByte((message.type.getValue() << 4) | (message.typeFlag & 0x0F));
dos.writeByte(SERIALIZATION_RAW | COMPRESSION_NONE);
dos.writeByte(0);
if (containsEvent(message.typeFlag)) {
dos.writeInt(message.event);
}
if (shouldWriteSessionId(message)) {
byte[] sessionIdBytes = message.sessionId.getBytes("UTF-8");
dos.writeInt(sessionIdBytes.length);
dos.write(sessionIdBytes);
}
if (message.payload != null) {
dos.writeInt(message.payload.length);
dos.write(message.payload);
} else {
dos.writeInt(0);
}
return baos.toByteArray();
}
public static byte[] createFullClientMessage(String sessionId, String payloadJson) throws IOException {
Message message = new Message();
message.type = MsgType.FULL_CLIENT;
message.typeFlag = MSG_TYPE_FLAG_WITH_EVENT;
message.sessionId = sessionId;
message.payload = payloadJson.getBytes("UTF-8");
return marshal(message);
}
public static String generateSessionId() {
return UUID.randomUUID().toString();
}
public static byte[] createJsonPayloadWithSpeaker(String sessionId, String text) throws IOException {
ObjectNode root = objectMapper.createObjectNode();
root.put("session_id", sessionId);
root.put("text", text);
root.put("speaker", Config.DEFAULT_SPEAKER);
return createFullClientMessage(sessionId, objectMapper.writeValueAsString(root));
}
}

Some files were not shown because too many files have changed in this diff Show More