知识库rerank设置

This commit is contained in:
2025-11-07 11:21:27 +08:00
parent d9947e273c
commit b98450df96
8 changed files with 360 additions and 34 deletions

View File

@@ -42,6 +42,7 @@ CREATE TABLE `tb_ai_knowledge` (
`embedding_model_provider` VARCHAR(100) DEFAULT NULL COMMENT '向量模型提供商',
`rerank_model` VARCHAR(100) DEFAULT NULL COMMENT 'Rerank模型名称',
`rerank_model_provider` VARCHAR(100) DEFAULT NULL COMMENT 'Rerank模型提供商',
`reranking_enable` TINYINT(1) DEFAULT 0 COMMENT '是否启用Rerank0否 1是',
`retrieval_top_k` INT(11) DEFAULT 2 COMMENT '检索Top K返回前K个结果',
`retrieval_score_threshold` DECIMAL(3,2) DEFAULT 0.00 COMMENT '检索分数阈值0.00-1.00',
`vector_id` VARCHAR(100) DEFAULT NULL COMMENT '向量ID用于向量检索',

View File

@@ -19,24 +19,24 @@ public class RetrievalModel {
@JSONField(name = "search_method")
private String searchMethod;
/**
* Rerank模型提供商
*/
@JSONField(name = "reranking_provider_name")
private String rerankingProviderName;
/**
* Rerank模型名称
*/
@JSONField(name = "reranking_model")
private String rerankingModel;
/**
* Rerank是否启用
*/
@JSONField(name = "reranking_enable")
private Boolean rerankingEnable;
/**
* Rerank模式字符串值为 "reranking_model"
*/
@JSONField(name = "reranking_mode")
private String rerankingMode;
/**
* Rerank模型配置当 reranking_enable=true 时必须设置)
*/
@JSONField(name = "reranking_model")
private RerankingModel rerankingModel;
/**
* Top K返回前K个结果
*/
@@ -54,5 +54,23 @@ public class RetrievalModel {
*/
@JSONField(name = "score_threshold_enabled")
private Boolean scoreThresholdEnabled;
/**
* Rerank模型配置嵌套对象
*/
@Data
public static class RerankingModel {
/**
* Rerank模型提供商
*/
@JSONField(name = "reranking_provider_name")
private String rerankingProviderName;
/**
* Rerank模型名称
*/
@JSONField(name = "reranking_model_name")
private String rerankingModelName;
}
}

View File

@@ -29,6 +29,8 @@ import org.xyzh.common.dto.user.TbSysUser;
import org.xyzh.common.vo.UserDeptRoleVO;
import org.xyzh.system.utils.LoginUtil;
import com.alibaba.fastjson2.JSON;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
@@ -118,7 +120,7 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
// 设置检索模型配置Rerank、Top K、Score 阈值)
RetrievalModel retrievalModel = new RetrievalModel();
retrievalModel.setSearchMethod("hybrid_search"); // 默认使用混合搜索
retrievalModel.setSearchMethod("hybrid_search"); // 必填字段
// Top K 配置
if (knowledge.getRetrievalTopK() != null && knowledge.getRetrievalTopK() > 0) {
@@ -136,20 +138,40 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
retrievalModel.setScoreThresholdEnabled(false);
}
// Rerank 模型配置
if (StringUtils.hasText(knowledge.getRerankModel())) {
retrievalModel.setRerankingEnable(true);
retrievalModel.setRerankingModel(knowledge.getRerankModel());
retrievalModel.setRerankingProviderName(knowledge.getRerankModelProvider());
log.info("创建知识库 - 启用Rerank: model={}, provider={}",
knowledge.getRerankModel(), knowledge.getRerankModelProvider());
// Rerank 模型配置(以前端传参为准)
Boolean rerankEnable = knowledge.getRerankingEnable() != null ?
knowledge.getRerankingEnable() : false;
retrievalModel.setRerankingEnable(rerankEnable);
if (rerankEnable) {
// 启用 Rerank 时model 和 provider 必须有值
if (!StringUtils.hasText(knowledge.getRerankModel())) {
throw new IllegalArgumentException("启用Rerank后必须指定rerankModel");
}
if (!StringUtils.hasText(knowledge.getRerankModelProvider())) {
throw new IllegalArgumentException("启用Rerank后必须指定rerankModelProvider");
}
// 设置 reranking_mode 为固定值 "reranking_model"
retrievalModel.setRerankingMode("reranking_model");
// 创建 RerankingModel 对象(嵌套在 reranking_model 字段中)
RetrievalModel.RerankingModel rerankingModel = new RetrievalModel.RerankingModel();
rerankingModel.setRerankingProviderName(knowledge.getRerankModelProvider());
rerankingModel.setRerankingModelName(knowledge.getRerankModel());
retrievalModel.setRerankingModel(rerankingModel);
log.info("创建知识库 - 启用Rerank: enable={}, mode=reranking_model, model={}, provider={}",
rerankEnable, knowledge.getRerankModel(), knowledge.getRerankModelProvider());
} else {
retrievalModel.setRerankingEnable(false);
// 禁用 Rerank不设置 rerankingMode 和 rerankingModel
log.info("创建知识库 - 禁用Rerank");
}
difyRequest.setRetrievalModel(retrievalModel);
// 调用Dify API创建知识库使用知识库API Key
log.info("创建知识库 - 请求参数: {}", JSON.toJSONString(difyRequest));
DatasetCreateResponse difyResponse = difyApiClient.createDataset(difyRequest);
difyDatasetId = difyResponse.getId();
@@ -300,6 +322,82 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
}
needUpdateDify = true;
}
// 检索配置变化Rerank、Top K、Score阈值
boolean retrievalConfigChanged = false;
// 检测 Rerank 开关状态变化
Boolean newRerankEnable = knowledge.getRerankingEnable();
Boolean existingRerankEnable = existing.getRerankingEnable();
boolean rerankEnableChanged = (newRerankEnable != null && !newRerankEnable.equals(existingRerankEnable));
// 检测 Rerank 模型变化
String newRerankModel = knowledge.getRerankModel();
String existingRerankModel = existing.getRerankModel();
boolean rerankModelChanged = (newRerankModel != null && !newRerankModel.equals(existingRerankModel));
if (rerankEnableChanged || rerankModelChanged ||
(knowledge.getRetrievalTopK() != null && !knowledge.getRetrievalTopK().equals(existing.getRetrievalTopK())) ||
(knowledge.getRetrievalScoreThreshold() != null && !knowledge.getRetrievalScoreThreshold().equals(existing.getRetrievalScoreThreshold()))) {
retrievalConfigChanged = true;
}
if (retrievalConfigChanged) {
RetrievalModel retrievalModel = new RetrievalModel();
retrievalModel.setSearchMethod("hybrid_search"); // 必填字段
// Top K
if (knowledge.getRetrievalTopK() != null) {
retrievalModel.setTopK(knowledge.getRetrievalTopK());
} else {
retrievalModel.setTopK(existing.getRetrievalTopK() != null ? existing.getRetrievalTopK() : 2);
}
// Score 阈值
Double scoreThreshold = knowledge.getRetrievalScoreThreshold() != null ?
knowledge.getRetrievalScoreThreshold() :
(existing.getRetrievalScoreThreshold() != null ? existing.getRetrievalScoreThreshold() : 0.0);
retrievalModel.setScoreThreshold(scoreThreshold);
retrievalModel.setScoreThresholdEnabled(scoreThreshold > 0);
// Rerank 配置(以前端传参为准)
Boolean finalRerankEnable = newRerankEnable != null ? newRerankEnable :
(existingRerankEnable != null ? existingRerankEnable : false);
String finalRerankModel = newRerankModel != null ? newRerankModel : existingRerankModel;
String finalRerankProvider = knowledge.getRerankModelProvider() != null ?
knowledge.getRerankModelProvider() : existing.getRerankModelProvider();
// 直接使用前端传入的开关状态
retrievalModel.setRerankingEnable(finalRerankEnable);
if (finalRerankEnable) {
// 启用 Rerank 时model 和 provider 必须有值
if (!StringUtils.hasText(finalRerankModel)) {
throw new IllegalArgumentException("启用Rerank后必须指定rerankModel");
}
if (!StringUtils.hasText(finalRerankProvider)) {
throw new IllegalArgumentException("启用Rerank后必须指定rerankModelProvider");
}
// 设置 reranking_mode 为固定值 "reranking_model"
retrievalModel.setRerankingMode("reranking_model");
// 创建 RerankingModel 对象(嵌套在 reranking_model 字段中)
RetrievalModel.RerankingModel rerankingModel = new RetrievalModel.RerankingModel();
rerankingModel.setRerankingProviderName(finalRerankProvider);
rerankingModel.setRerankingModelName(finalRerankModel);
retrievalModel.setRerankingModel(rerankingModel);
log.info("更新Rerank配置: 启用 - enable={}, mode=reranking_model, model={}, provider={}",
finalRerankEnable, finalRerankModel, finalRerankProvider);
} else {
// 禁用 Rerank不设置 rerankingMode 和 rerankingModel
log.info("更新Rerank配置: 禁用 - enable={}", finalRerankEnable);
}
updateRequest.setRetrievalModel(retrievalModel);
needUpdateDify = true;
}
// 同步到Dify
if (needUpdateDify && StringUtils.hasText(existing.getDifyDatasetId())) {

View File

@@ -21,6 +21,7 @@
<result column="embedding_model_provider" property="embeddingModelProvider" jdbcType="VARCHAR"/>
<result column="rerank_model" property="rerankModel" jdbcType="VARCHAR"/>
<result column="rerank_model_provider" property="rerankModelProvider" jdbcType="VARCHAR"/>
<result column="reranking_enable" property="rerankingEnable" jdbcType="BOOLEAN"/>
<result column="retrieval_top_k" property="retrievalTopK" jdbcType="INTEGER"/>
<result column="retrieval_score_threshold" property="retrievalScoreThreshold" jdbcType="DECIMAL"/>
<result column="vector_id" property="vectorID" jdbcType="VARCHAR"/>
@@ -40,7 +41,7 @@
<sql id="Base_Column_List">
id, title, avatar, description, content, source_type, source_id, file_name, file_path,
category, tags, dify_dataset_id, dify_indexing_technique, embedding_model, embedding_model_provider,
rerank_model, rerank_model_provider, retrieval_top_k, retrieval_score_threshold,
rerank_model, rerank_model_provider, reranking_enable, retrieval_top_k, retrieval_score_threshold,
vector_id, document_count, total_chunks, status, creator, creator_dept,
updater, create_time, update_time, delete_time, deleted
</sql>
@@ -170,13 +171,13 @@
INSERT INTO tb_ai_knowledge (
id, title, avatar, description, content, source_type, source_id, file_name, file_path,
category, tags, dify_dataset_id, dify_indexing_technique, embedding_model, embedding_model_provider,
rerank_model, rerank_model_provider, retrieval_top_k, retrieval_score_threshold,
rerank_model, rerank_model_provider, reranking_enable, retrieval_top_k, retrieval_score_threshold,
vector_id, document_count, total_chunks, status, creator, creator_dept,
updater, create_time, update_time, deleted
) VALUES (
#{ID}, #{title}, #{avatar}, #{description}, #{content}, #{sourceType}, #{sourceID}, #{fileName}, #{filePath},
#{category}, #{tags}, #{difyDatasetId}, #{difyIndexingTechnique}, #{embeddingModel}, #{embeddingModelProvider},
#{rerankModel}, #{rerankModelProvider}, #{retrievalTopK}, #{retrievalScoreThreshold},
#{rerankModel}, #{rerankModelProvider}, #{rerankingEnable}, #{retrievalTopK}, #{retrievalScoreThreshold},
#{vectorID}, #{documentCount}, #{totalChunks}, #{status}, #{creator}, #{creatorDept},
#{updater}, #{createTime}, #{updateTime}, #{deleted}
)
@@ -202,6 +203,7 @@
<if test="embeddingModelProvider != null">embedding_model_provider = #{embeddingModelProvider},</if>
<if test="rerankModel != null">rerank_model = #{rerankModel},</if>
<if test="rerankModelProvider != null">rerank_model_provider = #{rerankModelProvider},</if>
<if test="rerankingEnable != null">reranking_enable = #{rerankingEnable},</if>
<if test="retrievalTopK != null">retrieval_top_k = #{retrievalTopK},</if>
<if test="retrievalScoreThreshold != null">retrieval_score_threshold = #{retrievalScoreThreshold},</if>
<if test="vectorID != null">vector_id = #{vectorID},</if>

View File

@@ -93,6 +93,11 @@ public class TbAiKnowledge extends BaseDTO {
*/
private String rerankModelProvider;
/**
* @description 是否启用Rerank
*/
private Boolean rerankingEnable;
/**
* @description 检索Top K返回前K个结果
*/
@@ -322,6 +327,14 @@ public class TbAiKnowledge extends BaseDTO {
this.rerankModelProvider = rerankModelProvider;
}
public Boolean getRerankingEnable() {
return rerankingEnable;
}
public void setRerankingEnable(Boolean rerankingEnable) {
this.rerankingEnable = rerankingEnable;
}
public Integer getRetrievalTopK() {
return retrievalTopK;
}