dify
This commit is contained in:
@@ -0,0 +1,266 @@
|
||||
import {
|
||||
useCallback,
|
||||
} from 'react'
|
||||
import { produce } from 'immer'
|
||||
import { useStoreApi } from 'reactflow'
|
||||
import { useNodeDataUpdate } from '@/app/components/workflow/hooks'
|
||||
import type { ValueSelector } from '@/app/components/workflow/types'
|
||||
import {
|
||||
ChunkStructureEnum,
|
||||
IndexMethodEnum,
|
||||
RetrievalSearchMethodEnum,
|
||||
WeightedScoreEnum,
|
||||
} from '../types'
|
||||
import type {
|
||||
KnowledgeBaseNodeType,
|
||||
RerankingModel,
|
||||
} from '../types'
|
||||
import {
|
||||
HybridSearchModeEnum,
|
||||
} from '../types'
|
||||
import { isHighQualitySearchMethod } from '../utils'
|
||||
import { DEFAULT_WEIGHTED_SCORE, RerankingModeEnum } from '@/models/datasets'
|
||||
|
||||
export const useConfig = (id: string) => {
|
||||
const store = useStoreApi()
|
||||
const { handleNodeDataUpdateWithSyncDraft } = useNodeDataUpdate()
|
||||
|
||||
const getNodeData = useCallback(() => {
|
||||
const { getNodes } = store.getState()
|
||||
const nodes = getNodes()
|
||||
|
||||
return nodes.find(node => node.id === id)
|
||||
}, [store, id])
|
||||
|
||||
const handleNodeDataUpdate = useCallback((data: Partial<KnowledgeBaseNodeType>) => {
|
||||
handleNodeDataUpdateWithSyncDraft({
|
||||
id,
|
||||
data,
|
||||
})
|
||||
}, [id, handleNodeDataUpdateWithSyncDraft])
|
||||
|
||||
const getDefaultWeights = useCallback(({
|
||||
embeddingModel,
|
||||
embeddingModelProvider,
|
||||
}: {
|
||||
embeddingModel: string
|
||||
embeddingModelProvider: string
|
||||
}) => {
|
||||
return {
|
||||
vector_setting: {
|
||||
vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic,
|
||||
embedding_provider_name: embeddingModelProvider || '',
|
||||
embedding_model_name: embeddingModel,
|
||||
},
|
||||
keyword_setting: {
|
||||
keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword,
|
||||
},
|
||||
}
|
||||
}, [])
|
||||
|
||||
const handleChunkStructureChange = useCallback((chunkStructure: ChunkStructureEnum) => {
|
||||
const nodeData = getNodeData()
|
||||
const {
|
||||
indexing_technique,
|
||||
retrieval_model,
|
||||
chunk_structure,
|
||||
index_chunk_variable_selector,
|
||||
} = nodeData?.data || {}
|
||||
const { search_method } = retrieval_model || {}
|
||||
handleNodeDataUpdate({
|
||||
chunk_structure: chunkStructure,
|
||||
indexing_technique: (chunkStructure === ChunkStructureEnum.parent_child || chunkStructure === ChunkStructureEnum.question_answer) ? IndexMethodEnum.QUALIFIED : indexing_technique,
|
||||
retrieval_model: {
|
||||
...retrieval_model,
|
||||
search_method: ((chunkStructure === ChunkStructureEnum.parent_child || chunkStructure === ChunkStructureEnum.question_answer) && !isHighQualitySearchMethod(search_method)) ? RetrievalSearchMethodEnum.keywordSearch : search_method,
|
||||
},
|
||||
index_chunk_variable_selector: chunkStructure === chunk_structure ? index_chunk_variable_selector : [],
|
||||
})
|
||||
}, [handleNodeDataUpdate, getNodeData])
|
||||
|
||||
const handleIndexMethodChange = useCallback((indexMethod: IndexMethodEnum) => {
|
||||
const nodeData = getNodeData()
|
||||
|
||||
handleNodeDataUpdate(produce(nodeData?.data as KnowledgeBaseNodeType, (draft) => {
|
||||
draft.indexing_technique = indexMethod
|
||||
|
||||
if (indexMethod === IndexMethodEnum.ECONOMICAL)
|
||||
draft.retrieval_model.search_method = RetrievalSearchMethodEnum.keywordSearch
|
||||
else if (indexMethod === IndexMethodEnum.QUALIFIED)
|
||||
draft.retrieval_model.search_method = RetrievalSearchMethodEnum.semantic
|
||||
}))
|
||||
}, [handleNodeDataUpdate, getNodeData])
|
||||
|
||||
const handleKeywordNumberChange = useCallback((keywordNumber: number) => {
|
||||
handleNodeDataUpdate({ keyword_number: keywordNumber })
|
||||
}, [handleNodeDataUpdate])
|
||||
|
||||
const handleEmbeddingModelChange = useCallback(({
|
||||
embeddingModel,
|
||||
embeddingModelProvider,
|
||||
}: {
|
||||
embeddingModel: string
|
||||
embeddingModelProvider: string
|
||||
}) => {
|
||||
const nodeData = getNodeData()
|
||||
const defaultWeights = getDefaultWeights({
|
||||
embeddingModel,
|
||||
embeddingModelProvider,
|
||||
})
|
||||
const changeData = {
|
||||
embedding_model: embeddingModel,
|
||||
embedding_model_provider: embeddingModelProvider,
|
||||
retrieval_model: {
|
||||
...nodeData?.data.retrieval_model,
|
||||
},
|
||||
}
|
||||
if (changeData.retrieval_model.weights) {
|
||||
changeData.retrieval_model = {
|
||||
...changeData.retrieval_model,
|
||||
weights: {
|
||||
...changeData.retrieval_model.weights,
|
||||
vector_setting: {
|
||||
...changeData.retrieval_model.weights.vector_setting,
|
||||
embedding_provider_name: embeddingModelProvider,
|
||||
embedding_model_name: embeddingModel,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
else {
|
||||
changeData.retrieval_model = {
|
||||
...changeData.retrieval_model,
|
||||
weights: defaultWeights,
|
||||
}
|
||||
}
|
||||
handleNodeDataUpdate(changeData)
|
||||
}, [getNodeData, getDefaultWeights, handleNodeDataUpdate])
|
||||
|
||||
const handleRetrievalSearchMethodChange = useCallback((searchMethod: RetrievalSearchMethodEnum) => {
|
||||
const nodeData = getNodeData()
|
||||
const changeData = {
|
||||
retrieval_model: {
|
||||
...nodeData?.data.retrieval_model,
|
||||
search_method: searchMethod,
|
||||
reranking_mode: nodeData?.data.retrieval_model.reranking_mode || RerankingModeEnum.RerankingModel,
|
||||
},
|
||||
}
|
||||
if (searchMethod === RetrievalSearchMethodEnum.hybrid) {
|
||||
changeData.retrieval_model = {
|
||||
...changeData.retrieval_model,
|
||||
reranking_enable: changeData.retrieval_model.reranking_mode === RerankingModeEnum.RerankingModel,
|
||||
}
|
||||
}
|
||||
handleNodeDataUpdate(changeData)
|
||||
}, [getNodeData, handleNodeDataUpdate])
|
||||
|
||||
const handleHybridSearchModeChange = useCallback((hybridSearchMode: HybridSearchModeEnum) => {
|
||||
const nodeData = getNodeData()
|
||||
const defaultWeights = getDefaultWeights({
|
||||
embeddingModel: nodeData?.data.embedding_model || '',
|
||||
embeddingModelProvider: nodeData?.data.embedding_model_provider || '',
|
||||
})
|
||||
handleNodeDataUpdate({
|
||||
retrieval_model: {
|
||||
...nodeData?.data.retrieval_model,
|
||||
reranking_mode: hybridSearchMode,
|
||||
reranking_enable: hybridSearchMode === HybridSearchModeEnum.RerankingModel,
|
||||
weights: nodeData?.data.retrieval_model.weights || defaultWeights,
|
||||
},
|
||||
})
|
||||
}, [getNodeData, getDefaultWeights, handleNodeDataUpdate])
|
||||
|
||||
const handleRerankingModelEnabledChange = useCallback((rerankingModelEnabled: boolean) => {
|
||||
const nodeData = getNodeData()
|
||||
handleNodeDataUpdate({
|
||||
retrieval_model: {
|
||||
...nodeData?.data.retrieval_model,
|
||||
reranking_enable: rerankingModelEnabled,
|
||||
},
|
||||
})
|
||||
}, [getNodeData, handleNodeDataUpdate])
|
||||
|
||||
const handleWeighedScoreChange = useCallback((weightedScore: { value: number[] }) => {
|
||||
const nodeData = getNodeData()
|
||||
handleNodeDataUpdate({
|
||||
retrieval_model: {
|
||||
...nodeData?.data.retrieval_model,
|
||||
weights: {
|
||||
weight_type: WeightedScoreEnum.Customized,
|
||||
vector_setting: {
|
||||
...nodeData?.data.retrieval_model.weights?.vector_setting,
|
||||
vector_weight: weightedScore.value[0],
|
||||
},
|
||||
keyword_setting: {
|
||||
keyword_weight: weightedScore.value[1],
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}, [getNodeData, handleNodeDataUpdate])
|
||||
|
||||
const handleRerankingModelChange = useCallback((rerankingModel: RerankingModel) => {
|
||||
const nodeData = getNodeData()
|
||||
handleNodeDataUpdate({
|
||||
retrieval_model: {
|
||||
...nodeData?.data.retrieval_model,
|
||||
reranking_model: {
|
||||
reranking_provider_name: rerankingModel.reranking_provider_name,
|
||||
reranking_model_name: rerankingModel.reranking_model_name,
|
||||
},
|
||||
},
|
||||
})
|
||||
}, [getNodeData, handleNodeDataUpdate])
|
||||
|
||||
const handleTopKChange = useCallback((topK: number) => {
|
||||
const nodeData = getNodeData()
|
||||
handleNodeDataUpdate({
|
||||
retrieval_model: {
|
||||
...nodeData?.data.retrieval_model,
|
||||
top_k: topK,
|
||||
},
|
||||
})
|
||||
}, [getNodeData, handleNodeDataUpdate])
|
||||
|
||||
const handleScoreThresholdChange = useCallback((scoreThreshold: number) => {
|
||||
const nodeData = getNodeData()
|
||||
handleNodeDataUpdate({
|
||||
retrieval_model: {
|
||||
...nodeData?.data.retrieval_model,
|
||||
score_threshold: scoreThreshold,
|
||||
},
|
||||
})
|
||||
}, [getNodeData, handleNodeDataUpdate])
|
||||
|
||||
const handleScoreThresholdEnabledChange = useCallback((isEnabled: boolean) => {
|
||||
const nodeData = getNodeData()
|
||||
handleNodeDataUpdate({
|
||||
retrieval_model: {
|
||||
...nodeData?.data.retrieval_model,
|
||||
score_threshold_enabled: isEnabled,
|
||||
},
|
||||
})
|
||||
}, [getNodeData, handleNodeDataUpdate])
|
||||
|
||||
const handleInputVariableChange = useCallback((inputVariable: string | ValueSelector) => {
|
||||
handleNodeDataUpdate({
|
||||
index_chunk_variable_selector: Array.isArray(inputVariable) ? inputVariable : [],
|
||||
})
|
||||
}, [handleNodeDataUpdate])
|
||||
|
||||
return {
|
||||
handleChunkStructureChange,
|
||||
handleIndexMethodChange,
|
||||
handleKeywordNumberChange,
|
||||
handleEmbeddingModelChange,
|
||||
handleRetrievalSearchMethodChange,
|
||||
handleHybridSearchModeChange,
|
||||
handleRerankingModelEnabledChange,
|
||||
handleWeighedScoreChange,
|
||||
handleRerankingModelChange,
|
||||
handleTopKChange,
|
||||
handleScoreThresholdChange,
|
||||
handleScoreThresholdEnabledChange,
|
||||
handleInputVariableChange,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import {
|
||||
IndexMethodEnum,
|
||||
RetrievalSearchMethodEnum,
|
||||
} from '../types'
|
||||
|
||||
export const useSettingsDisplay = () => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return {
|
||||
[IndexMethodEnum.QUALIFIED]: t('datasetCreation.stepTwo.qualified'),
|
||||
[IndexMethodEnum.ECONOMICAL]: t('datasetSettings.form.indexMethodEconomy'),
|
||||
[RetrievalSearchMethodEnum.semantic]: t('dataset.retrieval.semantic_search.title'),
|
||||
[RetrievalSearchMethodEnum.fullText]: t('dataset.retrieval.full_text_search.title'),
|
||||
[RetrievalSearchMethodEnum.hybrid]: t('dataset.retrieval.hybrid_search.title'),
|
||||
[RetrievalSearchMethodEnum.keywordSearch]: t('dataset.retrieval.keyword_search.title'),
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user