dify
This commit is contained in:
@@ -0,0 +1,41 @@
|
||||
'use client'
|
||||
import { useBoolean } from 'ahooks'
|
||||
import type { FC } from 'react'
|
||||
import React, { useCallback } from 'react'
|
||||
import AddButton from '@/app/components/base/button/add-button'
|
||||
import SelectDataset from '@/app/components/app/configuration/dataset-config/select-dataset'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
|
||||
type Props = {
|
||||
selectedIds: string[]
|
||||
onChange: (dataSets: DataSet[]) => void
|
||||
}
|
||||
|
||||
const AddDataset: FC<Props> = ({
|
||||
selectedIds,
|
||||
onChange,
|
||||
}) => {
|
||||
const [isShowModal, {
|
||||
setTrue: showModal,
|
||||
setFalse: hideModal,
|
||||
}] = useBoolean(false)
|
||||
|
||||
const handleSelect = useCallback((datasets: DataSet[]) => {
|
||||
onChange(datasets)
|
||||
hideModal()
|
||||
}, [onChange, hideModal])
|
||||
return (
|
||||
<div>
|
||||
<AddButton onClick={showModal} />
|
||||
{isShowModal && (
|
||||
<SelectDataset
|
||||
isShow={isShowModal}
|
||||
onClose={hideModal}
|
||||
selectedIds={selectedIds}
|
||||
onSelect={handleSelect}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
export default React.memo(AddDataset)
|
||||
@@ -0,0 +1,126 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React, { useCallback, useState } from 'react'
|
||||
import { useBoolean } from 'ahooks'
|
||||
import {
|
||||
RiDeleteBinLine,
|
||||
RiEditLine,
|
||||
} from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import ActionButton, { ActionButtonState } from '@/app/components/base/action-button'
|
||||
import SettingsModal from '@/app/components/app/configuration/dataset-config/settings-modal'
|
||||
import Drawer from '@/app/components/base/drawer'
|
||||
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
||||
import Badge from '@/app/components/base/badge'
|
||||
import { useKnowledge } from '@/hooks/use-knowledge'
|
||||
import AppIcon from '@/app/components/base/app-icon'
|
||||
|
||||
type Props = {
|
||||
payload: DataSet
|
||||
onRemove: () => void
|
||||
onChange: (dataSet: DataSet) => void
|
||||
readonly?: boolean
|
||||
editable?: boolean
|
||||
}
|
||||
|
||||
const DatasetItem: FC<Props> = ({
|
||||
payload,
|
||||
onRemove,
|
||||
onChange,
|
||||
readonly,
|
||||
editable = true,
|
||||
}) => {
|
||||
const media = useBreakpoints()
|
||||
const { t } = useTranslation()
|
||||
const isMobile = media === MediaType.mobile
|
||||
const { formatIndexingTechniqueAndMethod } = useKnowledge()
|
||||
const [isDeleteHovered, setIsDeleteHovered] = useState(false)
|
||||
|
||||
const [isShowSettingsModal, {
|
||||
setTrue: showSettingsModal,
|
||||
setFalse: hideSettingsModal,
|
||||
}] = useBoolean(false)
|
||||
|
||||
const handleSave = useCallback((newDataset: DataSet) => {
|
||||
onChange(newDataset)
|
||||
hideSettingsModal()
|
||||
}, [hideSettingsModal, onChange])
|
||||
|
||||
const handleRemove = useCallback((e: React.MouseEvent) => {
|
||||
e.stopPropagation()
|
||||
onRemove()
|
||||
}, [onRemove])
|
||||
|
||||
const iconInfo = payload.icon_info || {
|
||||
icon: '📙',
|
||||
icon_type: 'emoji',
|
||||
icon_background: '#FFF4ED',
|
||||
icon_url: '',
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={`group/dataset-item flex h-10 cursor-pointer items-center justify-between rounded-lg
|
||||
border-[0.5px] border-components-panel-border-subtle px-2
|
||||
${isDeleteHovered
|
||||
? 'border-state-destructive-border bg-state-destructive-hover'
|
||||
: 'bg-components-panel-on-panel-item-bg hover:bg-components-panel-on-panel-item-bg-hover'
|
||||
}`}>
|
||||
<div className='flex w-0 grow items-center space-x-1.5'>
|
||||
<AppIcon
|
||||
size='tiny'
|
||||
iconType={iconInfo.icon_type}
|
||||
icon={iconInfo.icon}
|
||||
background={iconInfo.icon_type === 'image' ? undefined : iconInfo.icon_background}
|
||||
imageUrl={iconInfo.icon_type === 'image' ? iconInfo.icon_url : undefined}
|
||||
/>
|
||||
<div className='system-sm-medium w-0 grow truncate text-text-secondary'>{payload.name}</div>
|
||||
</div>
|
||||
{!readonly && (
|
||||
<div className='ml-2 hidden shrink-0 items-center space-x-1 group-hover/dataset-item:flex'>
|
||||
{
|
||||
editable && <ActionButton
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
showSettingsModal()
|
||||
}}
|
||||
>
|
||||
<RiEditLine className='h-4 w-4 shrink-0 text-text-tertiary' />
|
||||
</ActionButton>
|
||||
}
|
||||
<ActionButton
|
||||
onClick={handleRemove}
|
||||
state={isDeleteHovered ? ActionButtonState.Destructive : ActionButtonState.Default}
|
||||
onMouseEnter={() => setIsDeleteHovered(true)}
|
||||
onMouseLeave={() => setIsDeleteHovered(false)}
|
||||
>
|
||||
<RiDeleteBinLine className={`h-4 w-4 shrink-0 ${isDeleteHovered ? 'text-text-destructive' : 'text-text-tertiary'}`} />
|
||||
</ActionButton>
|
||||
</div>
|
||||
)}
|
||||
{
|
||||
payload.indexing_technique && <Badge
|
||||
className='shrink-0 group-hover/dataset-item:hidden'
|
||||
text={formatIndexingTechniqueAndMethod(payload.indexing_technique, payload.retrieval_model_dict?.search_method)}
|
||||
/>
|
||||
}
|
||||
{
|
||||
payload.provider === 'external' && <Badge
|
||||
className='shrink-0 group-hover/dataset-item:hidden'
|
||||
text={t('dataset.externalTag') as string}
|
||||
/>
|
||||
}
|
||||
|
||||
{isShowSettingsModal && (
|
||||
<Drawer isOpen={isShowSettingsModal} onClose={hideSettingsModal} footer={null} mask={isMobile} panelClassName='mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[640px] rounded-xl'>
|
||||
<SettingsModal
|
||||
currentDataset={payload}
|
||||
onCancel={hideSettingsModal}
|
||||
onSave={handleSave}
|
||||
/>
|
||||
</Drawer>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
export default React.memo(DatasetItem)
|
||||
@@ -0,0 +1,82 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React, { useCallback, useMemo } from 'react'
|
||||
import { produce } from 'immer'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Item from './dataset-item'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import { useSelector as useAppContextSelector } from '@/context/app-context'
|
||||
import { hasEditPermissionForDataset } from '@/utils/permission'
|
||||
|
||||
type Props = {
|
||||
list: DataSet[]
|
||||
onChange: (list: DataSet[]) => void
|
||||
readonly?: boolean
|
||||
}
|
||||
|
||||
const DatasetList: FC<Props> = ({
|
||||
list,
|
||||
onChange,
|
||||
readonly,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const userProfile = useAppContextSelector(s => s.userProfile)
|
||||
|
||||
const handleRemove = useCallback((index: number) => {
|
||||
return () => {
|
||||
const newList = produce(list, (draft) => {
|
||||
draft.splice(index, 1)
|
||||
})
|
||||
onChange(newList)
|
||||
}
|
||||
}, [list, onChange])
|
||||
|
||||
const handleChange = useCallback((index: number) => {
|
||||
return (value: DataSet) => {
|
||||
const newList = produce(list, (draft) => {
|
||||
draft[index] = value
|
||||
})
|
||||
onChange(newList)
|
||||
}
|
||||
}, [list, onChange])
|
||||
|
||||
const formattedList = useMemo(() => {
|
||||
return list.map((item) => {
|
||||
const datasetConfig = {
|
||||
createdBy: item.created_by,
|
||||
partialMemberList: item.partial_member_list || [],
|
||||
permission: item.permission,
|
||||
}
|
||||
return {
|
||||
...item,
|
||||
editable: hasEditPermissionForDataset(userProfile?.id || '', datasetConfig),
|
||||
}
|
||||
})
|
||||
}, [list, userProfile?.id])
|
||||
|
||||
return (
|
||||
<div className='space-y-1'>
|
||||
{formattedList.length
|
||||
? formattedList.map((item, index) => {
|
||||
return (
|
||||
<Item
|
||||
key={index}
|
||||
payload={item}
|
||||
onRemove={handleRemove(index)}
|
||||
onChange={handleChange(index)}
|
||||
readonly={readonly}
|
||||
editable={item.editable}
|
||||
/>
|
||||
)
|
||||
})
|
||||
: (
|
||||
<div className='cursor-default select-none rounded-lg bg-background-section p-3 text-center text-xs text-text-tertiary'>
|
||||
{t('appDebug.datasetConfig.knowledgeTip')}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
</div>
|
||||
)
|
||||
}
|
||||
export default React.memo(DatasetList)
|
||||
@@ -0,0 +1,95 @@
|
||||
import {
|
||||
useCallback,
|
||||
useMemo,
|
||||
useState,
|
||||
} from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import {
|
||||
RiAddLine,
|
||||
} from '@remixicon/react'
|
||||
import MetadataIcon from './metadata-icon'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Input from '@/app/components/base/input'
|
||||
import type { MetadataShape } from '@/app/components/workflow/nodes/knowledge-retrieval/types'
|
||||
import type { MetadataInDoc } from '@/models/datasets'
|
||||
|
||||
const AddCondition = ({
|
||||
metadataList,
|
||||
handleAddCondition,
|
||||
}: Pick<MetadataShape, 'handleAddCondition' | 'metadataList'>) => {
|
||||
const { t } = useTranslation()
|
||||
const [open, setOpen] = useState(false)
|
||||
const [searchText, setSearchText] = useState('')
|
||||
|
||||
const filteredMetadataList = useMemo(() => {
|
||||
return metadataList?.filter(metadata => metadata.name.includes(searchText))
|
||||
}, [metadataList, searchText])
|
||||
|
||||
const handleAddConditionWrapped = useCallback((item: MetadataInDoc) => {
|
||||
handleAddCondition?.(item)
|
||||
setOpen(false)
|
||||
}, [handleAddCondition])
|
||||
|
||||
return (
|
||||
<PortalToFollowElem
|
||||
open={open}
|
||||
onOpenChange={setOpen}
|
||||
placement='bottom-start'
|
||||
offset={{
|
||||
mainAxis: 3,
|
||||
crossAxis: 0,
|
||||
}}
|
||||
>
|
||||
<PortalToFollowElemTrigger onClick={() => setOpen(!open)}>
|
||||
<Button
|
||||
size='small'
|
||||
variant='secondary'
|
||||
>
|
||||
<RiAddLine className='h-3.5 w-3.5' />
|
||||
{t('workflow.nodes.knowledgeRetrieval.metadata.panel.add')}
|
||||
</Button>
|
||||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent className='z-10'>
|
||||
<div className='w-[320px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur shadow-lg'>
|
||||
<div className='p-2 pb-1'>
|
||||
<Input
|
||||
showLeftIcon
|
||||
placeholder={t('workflow.nodes.knowledgeRetrieval.metadata.panel.search')}
|
||||
value={searchText}
|
||||
onChange={e => setSearchText(e.target.value)}
|
||||
/>
|
||||
</div>
|
||||
<div className='p-1'>
|
||||
{
|
||||
filteredMetadataList?.map(metadata => (
|
||||
<div
|
||||
key={metadata.name}
|
||||
className='system-sm-medium flex h-6 cursor-pointer items-center rounded-md px-3 text-text-secondary hover:bg-state-base-hover'
|
||||
>
|
||||
<div className='mr-1 p-[1px]'>
|
||||
<MetadataIcon type={metadata.type} />
|
||||
</div>
|
||||
<div
|
||||
className='grow truncate'
|
||||
title={metadata.name}
|
||||
onClick={() => handleAddConditionWrapped(metadata)}
|
||||
>
|
||||
{metadata.name}
|
||||
</div>
|
||||
<div className='system-xs-regular shrink-0 text-text-tertiary'>{metadata.type}</div>
|
||||
</div>
|
||||
))
|
||||
}
|
||||
</div>
|
||||
</div>
|
||||
</PortalToFollowElemContent>
|
||||
</PortalToFollowElem>
|
||||
)
|
||||
}
|
||||
|
||||
export default AddCondition
|
||||
@@ -0,0 +1,91 @@
|
||||
import { useCallback, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
import type { VarType } from '@/app/components/workflow/types'
|
||||
import { Variable02 } from '@/app/components/base/icons/src/vender/solid/development'
|
||||
|
||||
type ConditionCommonVariableSelectorProps = {
|
||||
variables?: { name: string; type: string; value: string }[]
|
||||
value?: string | number
|
||||
varType?: VarType
|
||||
onChange: (v: string) => void
|
||||
}
|
||||
|
||||
const ConditionCommonVariableSelector = ({
|
||||
variables = [],
|
||||
value,
|
||||
onChange,
|
||||
varType,
|
||||
}: ConditionCommonVariableSelectorProps) => {
|
||||
const { t } = useTranslation()
|
||||
const [open, setOpen] = useState(false)
|
||||
|
||||
const selected = variables.find(v => v.value === value)
|
||||
const handleChange = useCallback((v: string) => {
|
||||
onChange(v)
|
||||
setOpen(false)
|
||||
}, [onChange])
|
||||
|
||||
return (
|
||||
<PortalToFollowElem
|
||||
open={open}
|
||||
onOpenChange={setOpen}
|
||||
placement='bottom-start'
|
||||
offset={{
|
||||
mainAxis: 4,
|
||||
crossAxis: 0,
|
||||
}}
|
||||
>
|
||||
<PortalToFollowElemTrigger asChild onClick={() => {
|
||||
if (!variables.length) return
|
||||
setOpen(!open)
|
||||
}}>
|
||||
<div className="flex h-6 grow cursor-pointer items-center">
|
||||
{
|
||||
selected && (
|
||||
<div className='system-xs-medium inline-flex h-6 items-center rounded-md border-[0.5px] border-components-panel-border-subtle bg-components-badge-white-to-dark pl-[5px] pr-1.5 text-text-secondary shadow-xs'>
|
||||
<Variable02 className='mr-1 h-3.5 w-3.5 text-text-accent' />
|
||||
{selected.value}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
!selected && (
|
||||
<>
|
||||
<div className='system-sm-regular flex grow items-center text-components-input-text-placeholder'>
|
||||
<Variable02 className='mr-1 h-4 w-4' />
|
||||
{t('workflow.nodes.knowledgeRetrieval.metadata.panel.select')}
|
||||
</div>
|
||||
<div className='system-2xs-medium flex h-5 shrink-0 items-center rounded-[5px] border border-divider-deep px-[5px] text-text-tertiary'>
|
||||
{varType}
|
||||
</div>
|
||||
</>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent className='z-[1000]'>
|
||||
<div className='w-[200px] rounded-lg border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-1 shadow-lg'>
|
||||
{
|
||||
variables.map(v => (
|
||||
<div
|
||||
key={v.value}
|
||||
className='system-xs-medium flex h-6 cursor-pointer items-center rounded-md px-2 text-text-secondary hover:bg-state-base-hover'
|
||||
onClick={() => handleChange(v.value)}
|
||||
>
|
||||
<Variable02 className='mr-1 h-4 w-4 text-text-accent' />
|
||||
{v.value}
|
||||
</div>
|
||||
))
|
||||
}
|
||||
</div>
|
||||
</PortalToFollowElemContent>
|
||||
</PortalToFollowElem>
|
||||
)
|
||||
}
|
||||
|
||||
export default ConditionCommonVariableSelector
|
||||
@@ -0,0 +1,86 @@
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import dayjs from 'dayjs'
|
||||
import {
|
||||
RiCalendarLine,
|
||||
RiCloseCircleFill,
|
||||
} from '@remixicon/react'
|
||||
import DatePicker from '@/app/components/base/date-and-time-picker/date-picker'
|
||||
import type { TriggerProps } from '@/app/components/base/date-and-time-picker/types'
|
||||
import cn from '@/utils/classnames'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
|
||||
type ConditionDateProps = {
|
||||
value?: number
|
||||
onChange: (date?: number) => void
|
||||
}
|
||||
const ConditionDate = ({
|
||||
value,
|
||||
onChange,
|
||||
}: ConditionDateProps) => {
|
||||
const { t } = useTranslation()
|
||||
const { userProfile: { timezone } } = useAppContext()
|
||||
|
||||
const handleDateChange = useCallback((date?: dayjs.Dayjs) => {
|
||||
if (date)
|
||||
onChange(date.unix())
|
||||
else
|
||||
onChange()
|
||||
}, [onChange])
|
||||
|
||||
const renderTrigger = useCallback(({
|
||||
handleClickTrigger,
|
||||
}: TriggerProps) => {
|
||||
return (
|
||||
<div className='group flex items-center' onClick={handleClickTrigger}>
|
||||
<div
|
||||
className={cn(
|
||||
'system-sm-regular mr-0.5 flex h-6 grow cursor-pointer items-center px-1',
|
||||
value ? 'text-text-secondary' : 'text-text-tertiary',
|
||||
)}
|
||||
>
|
||||
{
|
||||
value
|
||||
? dayjs(value * 1000).tz(timezone).format('MMMM DD YYYY HH:mm A')
|
||||
: t('workflow.nodes.knowledgeRetrieval.metadata.panel.datePlaceholder')
|
||||
}
|
||||
</div>
|
||||
{
|
||||
value && (
|
||||
<RiCloseCircleFill
|
||||
className={cn(
|
||||
'hidden h-4 w-4 shrink-0 cursor-pointer hover:text-components-input-text-filled group-hover:block',
|
||||
value && 'text-text-quaternary',
|
||||
)}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
handleDateChange()
|
||||
}}
|
||||
/>
|
||||
)
|
||||
}
|
||||
<RiCalendarLine
|
||||
className={cn(
|
||||
'block h-4 w-4 shrink-0',
|
||||
value ? 'text-text-quaternary' : 'text-text-tertiary',
|
||||
value && 'group-hover:hidden',
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}, [value, handleDateChange, timezone, t])
|
||||
|
||||
return (
|
||||
<div className='h-8 px-2 py-1'>
|
||||
<DatePicker
|
||||
timezone={timezone}
|
||||
value={value ? dayjs(value * 1000) : undefined}
|
||||
onChange={handleDateChange}
|
||||
onClear={handleDateChange}
|
||||
renderTrigger={renderTrigger}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default ConditionDate
|
||||
@@ -0,0 +1,196 @@
|
||||
import {
|
||||
useCallback,
|
||||
useMemo,
|
||||
useState,
|
||||
} from 'react'
|
||||
import {
|
||||
RiDeleteBinLine,
|
||||
} from '@remixicon/react'
|
||||
import MetadataIcon from '../metadata-icon'
|
||||
import {
|
||||
COMMON_VARIABLE_REGEX,
|
||||
VARIABLE_REGEX,
|
||||
comparisonOperatorNotRequireValue,
|
||||
} from './utils'
|
||||
import ConditionOperator from './condition-operator'
|
||||
import ConditionString from './condition-string'
|
||||
import ConditionNumber from './condition-number'
|
||||
import ConditionDate from './condition-date'
|
||||
import type {
|
||||
ComparisonOperator,
|
||||
HandleRemoveCondition,
|
||||
HandleUpdateCondition,
|
||||
MetadataFilteringCondition,
|
||||
MetadataShape,
|
||||
} from '@/app/components/workflow/nodes/knowledge-retrieval/types'
|
||||
import { MetadataFilteringVariableType } from '@/app/components/workflow/nodes/knowledge-retrieval/types'
|
||||
import cn from '@/utils/classnames'
|
||||
|
||||
type ConditionItemProps = {
|
||||
className?: string
|
||||
disabled?: boolean
|
||||
condition: MetadataFilteringCondition // condition may the condition of case or condition of sub variable
|
||||
onRemoveCondition?: HandleRemoveCondition
|
||||
onUpdateCondition?: HandleUpdateCondition
|
||||
} & Pick<MetadataShape, 'metadataList' | 'availableStringVars' | 'availableStringNodesWithParent' | 'availableNumberVars' | 'availableNumberNodesWithParent' | 'isCommonVariable' | 'availableCommonStringVars' | 'availableCommonNumberVars'>
|
||||
const ConditionItem = ({
|
||||
className,
|
||||
disabled,
|
||||
condition,
|
||||
onRemoveCondition,
|
||||
onUpdateCondition,
|
||||
metadataList = [],
|
||||
availableStringVars = [],
|
||||
availableStringNodesWithParent = [],
|
||||
availableNumberVars = [],
|
||||
availableNumberNodesWithParent = [],
|
||||
isCommonVariable,
|
||||
availableCommonStringVars = [],
|
||||
availableCommonNumberVars = [],
|
||||
}: ConditionItemProps) => {
|
||||
const [isHovered, setIsHovered] = useState(false)
|
||||
|
||||
const canChooseOperator = useMemo(() => {
|
||||
if (disabled)
|
||||
return false
|
||||
|
||||
return true
|
||||
}, [disabled])
|
||||
|
||||
const doRemoveCondition = useCallback(() => {
|
||||
onRemoveCondition?.(condition.id)
|
||||
}, [onRemoveCondition, condition.id])
|
||||
|
||||
const currentMetadata = useMemo(() => {
|
||||
return metadataList.find(metadata => metadata.name === condition.name)
|
||||
}, [metadataList, condition.name])
|
||||
|
||||
const handleConditionOperatorChange = useCallback((operator: ComparisonOperator) => {
|
||||
onUpdateCondition?.(
|
||||
condition.id,
|
||||
{
|
||||
...condition,
|
||||
value: comparisonOperatorNotRequireValue(condition.comparison_operator) ? undefined : condition.value,
|
||||
comparison_operator: operator,
|
||||
})
|
||||
}, [onUpdateCondition, condition])
|
||||
|
||||
const valueAndValueMethod = useMemo(() => {
|
||||
if (
|
||||
(currentMetadata?.type === MetadataFilteringVariableType.string
|
||||
|| currentMetadata?.type === MetadataFilteringVariableType.number
|
||||
|| currentMetadata?.type === MetadataFilteringVariableType.select)
|
||||
&& typeof condition.value === 'string'
|
||||
) {
|
||||
const regex = isCommonVariable ? COMMON_VARIABLE_REGEX : VARIABLE_REGEX
|
||||
const matchedStartNumber = isCommonVariable ? 2 : 3
|
||||
const matched = condition.value.match(regex)
|
||||
|
||||
if (matched?.length) {
|
||||
return {
|
||||
value: matched[0].slice(matchedStartNumber, -matchedStartNumber),
|
||||
valueMethod: 'variable',
|
||||
}
|
||||
}
|
||||
else {
|
||||
return {
|
||||
value: condition.value,
|
||||
valueMethod: 'constant',
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
value: condition.value,
|
||||
valueMethod: 'constant',
|
||||
}
|
||||
}, [currentMetadata, condition.value, isCommonVariable])
|
||||
const [localValueMethod, setLocalValueMethod] = useState(valueAndValueMethod.valueMethod)
|
||||
|
||||
const handleValueMethodChange = useCallback((v: string) => {
|
||||
setLocalValueMethod(v)
|
||||
onUpdateCondition?.(condition.id, { ...condition, value: undefined })
|
||||
}, [condition, onUpdateCondition])
|
||||
|
||||
const handleValueChange = useCallback((v: any) => {
|
||||
onUpdateCondition?.(condition.id, { ...condition, value: v })
|
||||
}, [condition, onUpdateCondition])
|
||||
|
||||
return (
|
||||
<div className={cn('mb-1 flex last-of-type:mb-0', className)}>
|
||||
<div className={cn(
|
||||
'grow rounded-lg bg-components-input-bg-normal',
|
||||
isHovered && 'bg-state-destructive-hover',
|
||||
)}>
|
||||
<div className='flex items-center p-1'>
|
||||
<div className='w-0 grow'>
|
||||
<div className='flex h-6 min-w-0 items-center rounded-md border-[0.5px] border-components-panel-border-subtle bg-components-badge-white-to-dark pl-1 pr-1.5 shadow-xs'>
|
||||
<div className='mr-0.5 p-[1px]'>
|
||||
<MetadataIcon type={currentMetadata?.type} className='h-3 w-3' />
|
||||
</div>
|
||||
<div className='system-xs-medium mr-0.5 min-w-0 flex-1 truncate text-text-secondary'>{currentMetadata?.name}</div>
|
||||
<div className='system-xs-regular text-text-tertiary'>{currentMetadata?.type}</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className='mx-1 h-3 w-[1px] bg-divider-regular'></div>
|
||||
<ConditionOperator
|
||||
disabled={!canChooseOperator}
|
||||
variableType={currentMetadata?.type || MetadataFilteringVariableType.string}
|
||||
value={condition.comparison_operator}
|
||||
onSelect={handleConditionOperatorChange}
|
||||
/>
|
||||
</div>
|
||||
<div className='border-t border-t-divider-subtle'>
|
||||
{
|
||||
!comparisonOperatorNotRequireValue(condition.comparison_operator)
|
||||
&& (currentMetadata?.type === MetadataFilteringVariableType.string
|
||||
|| currentMetadata?.type === MetadataFilteringVariableType.select) && (
|
||||
<ConditionString
|
||||
valueMethod={localValueMethod}
|
||||
onValueMethodChange={handleValueMethodChange}
|
||||
nodesOutputVars={availableStringVars}
|
||||
availableNodes={availableStringNodesWithParent}
|
||||
value={valueAndValueMethod.value as string}
|
||||
onChange={handleValueChange}
|
||||
isCommonVariable={isCommonVariable}
|
||||
commonVariables={availableCommonStringVars}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
!comparisonOperatorNotRequireValue(condition.comparison_operator) && currentMetadata?.type === MetadataFilteringVariableType.number && (
|
||||
<ConditionNumber
|
||||
valueMethod={localValueMethod}
|
||||
onValueMethodChange={handleValueMethodChange}
|
||||
nodesOutputVars={availableNumberVars}
|
||||
availableNodes={availableNumberNodesWithParent}
|
||||
value={valueAndValueMethod.value}
|
||||
onChange={handleValueChange}
|
||||
isCommonVariable={isCommonVariable}
|
||||
commonVariables={availableCommonNumberVars}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
!comparisonOperatorNotRequireValue(condition.comparison_operator) && currentMetadata?.type === MetadataFilteringVariableType.time && (
|
||||
<ConditionDate
|
||||
value={condition.value as number}
|
||||
onChange={handleValueChange}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
className='ml-1 mt-1 flex h-6 w-6 shrink-0 cursor-pointer items-center justify-center rounded-lg text-text-tertiary hover:bg-state-destructive-hover hover:text-text-destructive'
|
||||
onMouseEnter={() => setIsHovered(true)}
|
||||
onMouseLeave={() => setIsHovered(false)}
|
||||
onClick={doRemoveCondition}
|
||||
>
|
||||
<RiDeleteBinLine className='h-4 w-4' />
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default ConditionItem
|
||||
@@ -0,0 +1,88 @@
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import ConditionValueMethod from './condition-value-method'
|
||||
import type { ConditionValueMethodProps } from './condition-value-method'
|
||||
import ConditionVariableSelector from './condition-variable-selector'
|
||||
import ConditionCommonVariableSelector from './condition-common-variable-selector'
|
||||
import type {
|
||||
Node,
|
||||
NodeOutPutVar,
|
||||
ValueSelector,
|
||||
} from '@/app/components/workflow/types'
|
||||
import { VarType } from '@/app/components/workflow/types'
|
||||
import Input from '@/app/components/base/input'
|
||||
|
||||
type ConditionNumberProps = {
|
||||
value?: string | number
|
||||
onChange: (value?: string | number) => void
|
||||
nodesOutputVars: NodeOutPutVar[]
|
||||
availableNodes: Node[]
|
||||
isCommonVariable?: boolean
|
||||
commonVariables: { name: string; type: string; value: string }[]
|
||||
} & ConditionValueMethodProps
|
||||
const ConditionNumber = ({
|
||||
value,
|
||||
onChange,
|
||||
valueMethod,
|
||||
onValueMethodChange,
|
||||
nodesOutputVars,
|
||||
availableNodes,
|
||||
isCommonVariable,
|
||||
commonVariables,
|
||||
}: ConditionNumberProps) => {
|
||||
const { t } = useTranslation()
|
||||
const handleVariableValueChange = useCallback((v: ValueSelector) => {
|
||||
onChange(`{{#${v.join('.')}#}}`)
|
||||
}, [onChange])
|
||||
|
||||
const handleCommonVariableValueChange = useCallback((v: string) => {
|
||||
onChange(`{{${v}}}`)
|
||||
}, [onChange])
|
||||
|
||||
return (
|
||||
<div className='flex h-8 items-center pl-1 pr-2'>
|
||||
<ConditionValueMethod
|
||||
valueMethod={valueMethod}
|
||||
onValueMethodChange={onValueMethodChange}
|
||||
/>
|
||||
<div className='ml-1 mr-1.5 h-4 w-[1px] bg-divider-regular'></div>
|
||||
{
|
||||
valueMethod === 'variable' && !isCommonVariable && (
|
||||
<ConditionVariableSelector
|
||||
valueSelector={value ? (value as string).split('.') : []}
|
||||
onChange={handleVariableValueChange}
|
||||
nodesOutputVars={nodesOutputVars}
|
||||
availableNodes={availableNodes}
|
||||
varType={VarType.number}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
valueMethod === 'variable' && isCommonVariable && (
|
||||
<ConditionCommonVariableSelector
|
||||
variables={commonVariables}
|
||||
value={value}
|
||||
onChange={handleCommonVariableValueChange}
|
||||
varType={VarType.number}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
valueMethod === 'constant' && (
|
||||
<Input
|
||||
className='border-none bg-transparent outline-none hover:bg-transparent focus:bg-transparent focus:shadow-none'
|
||||
value={value}
|
||||
onChange={(e) => {
|
||||
const v = e.target.value
|
||||
onChange(v ? Number(e.target.value) : undefined)
|
||||
}}
|
||||
placeholder={t('workflow.nodes.knowledgeRetrieval.metadata.panel.placeholder')}
|
||||
type='number'
|
||||
/>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default ConditionNumber
|
||||
@@ -0,0 +1,98 @@
|
||||
import {
|
||||
useMemo,
|
||||
useState,
|
||||
} from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { RiArrowDownSLine } from '@remixicon/react'
|
||||
import {
|
||||
getOperators,
|
||||
isComparisonOperatorNeedTranslate,
|
||||
} from './utils'
|
||||
import Button from '@/app/components/base/button'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
import cn from '@/utils/classnames'
|
||||
import type {
|
||||
ComparisonOperator,
|
||||
MetadataFilteringVariableType,
|
||||
} from '@/app/components/workflow/nodes/knowledge-retrieval/types'
|
||||
|
||||
const i18nPrefix = 'workflow.nodes.ifElse'
|
||||
|
||||
type ConditionOperatorProps = {
|
||||
className?: string
|
||||
disabled?: boolean
|
||||
variableType: MetadataFilteringVariableType
|
||||
value?: string
|
||||
onSelect: (value: ComparisonOperator) => void
|
||||
}
|
||||
const ConditionOperator = ({
|
||||
className,
|
||||
disabled,
|
||||
variableType,
|
||||
value,
|
||||
onSelect,
|
||||
}: ConditionOperatorProps) => {
|
||||
const { t } = useTranslation()
|
||||
const [open, setOpen] = useState(false)
|
||||
|
||||
const options = useMemo(() => {
|
||||
return getOperators(variableType).map((o) => {
|
||||
return {
|
||||
label: isComparisonOperatorNeedTranslate(o) ? t(`${i18nPrefix}.comparisonOperator.${o}`) : o,
|
||||
value: o,
|
||||
}
|
||||
})
|
||||
}, [t, variableType])
|
||||
const selectedOption = options.find(o => Array.isArray(value) ? o.value === value[0] : o.value === value)
|
||||
return (
|
||||
<PortalToFollowElem
|
||||
open={open}
|
||||
onOpenChange={setOpen}
|
||||
placement='bottom-end'
|
||||
offset={{
|
||||
mainAxis: 4,
|
||||
crossAxis: 0,
|
||||
}}
|
||||
>
|
||||
<PortalToFollowElemTrigger onClick={() => setOpen(v => !v)}>
|
||||
<Button
|
||||
className={cn('shrink-0', !selectedOption && 'opacity-50', className)}
|
||||
size='small'
|
||||
variant='ghost'
|
||||
disabled={disabled}
|
||||
>
|
||||
{
|
||||
selectedOption
|
||||
? selectedOption.label
|
||||
: t(`${i18nPrefix}.select`)
|
||||
}
|
||||
<RiArrowDownSLine className='ml-1 h-3.5 w-3.5' />
|
||||
</Button>
|
||||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent className='z-10'>
|
||||
<div className='rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-1 shadow-lg'>
|
||||
{
|
||||
options.map(option => (
|
||||
<div
|
||||
key={option.value}
|
||||
className='flex h-7 cursor-pointer items-center rounded-lg px-3 py-1.5 text-[13px] font-medium text-text-secondary hover:bg-state-base-hover'
|
||||
onClick={() => {
|
||||
onSelect(option.value)
|
||||
setOpen(false)
|
||||
}}
|
||||
>
|
||||
{option.label}
|
||||
</div>
|
||||
))
|
||||
}
|
||||
</div>
|
||||
</PortalToFollowElemContent>
|
||||
</PortalToFollowElem>
|
||||
)
|
||||
}
|
||||
|
||||
export default ConditionOperator
|
||||
@@ -0,0 +1,84 @@
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import ConditionValueMethod from './condition-value-method'
|
||||
import type { ConditionValueMethodProps } from './condition-value-method'
|
||||
import ConditionVariableSelector from './condition-variable-selector'
|
||||
import ConditionCommonVariableSelector from './condition-common-variable-selector'
|
||||
import type {
|
||||
Node,
|
||||
NodeOutPutVar,
|
||||
ValueSelector,
|
||||
} from '@/app/components/workflow/types'
|
||||
import Input from '@/app/components/base/input'
|
||||
import { VarType } from '@/app/components/workflow/types'
|
||||
|
||||
type ConditionStringProps = {
|
||||
value?: string
|
||||
onChange: (value: string) => void
|
||||
nodesOutputVars: NodeOutPutVar[]
|
||||
availableNodes: Node[]
|
||||
isCommonVariable?: boolean
|
||||
commonVariables: { name: string; type: string; value: string }[]
|
||||
} & ConditionValueMethodProps
|
||||
const ConditionString = ({
|
||||
value,
|
||||
onChange,
|
||||
valueMethod = 'constant',
|
||||
onValueMethodChange,
|
||||
nodesOutputVars,
|
||||
availableNodes,
|
||||
isCommonVariable,
|
||||
commonVariables,
|
||||
}: ConditionStringProps) => {
|
||||
const { t } = useTranslation()
|
||||
const handleVariableValueChange = useCallback((v: ValueSelector) => {
|
||||
onChange(`{{#${v.join('.')}#}}`)
|
||||
}, [onChange])
|
||||
|
||||
const handleCommonVariableValueChange = useCallback((v: string) => {
|
||||
onChange(`{{${v}}}`)
|
||||
}, [onChange])
|
||||
|
||||
return (
|
||||
<div className='flex h-8 items-center pl-1 pr-2'>
|
||||
<ConditionValueMethod
|
||||
valueMethod={valueMethod}
|
||||
onValueMethodChange={onValueMethodChange}
|
||||
/>
|
||||
<div className='ml-1 mr-1.5 h-4 w-[1px] bg-divider-regular'></div>
|
||||
{
|
||||
valueMethod === 'variable' && !isCommonVariable && (
|
||||
<ConditionVariableSelector
|
||||
valueSelector={value ? value!.split('.') : []}
|
||||
onChange={handleVariableValueChange}
|
||||
nodesOutputVars={nodesOutputVars}
|
||||
availableNodes={availableNodes}
|
||||
varType={VarType.string}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
valueMethod === 'variable' && isCommonVariable && (
|
||||
<ConditionCommonVariableSelector
|
||||
variables={commonVariables}
|
||||
value={value}
|
||||
onChange={handleCommonVariableValueChange}
|
||||
varType={VarType.string}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
valueMethod === 'constant' && (
|
||||
<Input
|
||||
className='border-none bg-transparent outline-none hover:bg-transparent focus:bg-transparent focus:shadow-none'
|
||||
value={value}
|
||||
onChange={e => onChange(e.target.value)}
|
||||
placeholder={t('workflow.nodes.knowledgeRetrieval.metadata.panel.placeholder')}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default ConditionString
|
||||
@@ -0,0 +1,71 @@
|
||||
import { useState } from 'react'
|
||||
import { capitalize } from 'lodash-es'
|
||||
import { RiArrowDownSLine } from '@remixicon/react'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
import Button from '@/app/components/base/button'
|
||||
import cn from '@/utils/classnames'
|
||||
|
||||
export type ConditionValueMethodProps = {
|
||||
valueMethod?: string
|
||||
onValueMethodChange: (v: string) => void
|
||||
}
|
||||
const options = [
|
||||
'variable',
|
||||
'constant',
|
||||
]
|
||||
const ConditionValueMethod = ({
|
||||
valueMethod = 'variable',
|
||||
onValueMethodChange,
|
||||
}: ConditionValueMethodProps) => {
|
||||
const [open, setOpen] = useState(false)
|
||||
|
||||
return (
|
||||
<PortalToFollowElem
|
||||
open={open}
|
||||
onOpenChange={setOpen}
|
||||
placement='bottom-start'
|
||||
offset={{ mainAxis: 4, crossAxis: 0 }}
|
||||
>
|
||||
<PortalToFollowElemTrigger asChild onClick={() => setOpen(v => !v)}>
|
||||
<Button
|
||||
className='shrink-0'
|
||||
variant='ghost'
|
||||
size='small'
|
||||
>
|
||||
{capitalize(valueMethod)}
|
||||
<RiArrowDownSLine className='ml-[1px] h-3.5 w-3.5' />
|
||||
</Button>
|
||||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent className='z-[1000]'>
|
||||
<div className='w-[112px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-1 shadow-lg'>
|
||||
{
|
||||
options.map(option => (
|
||||
<div
|
||||
key={option}
|
||||
className={cn(
|
||||
'flex h-7 cursor-pointer items-center rounded-md px-3 hover:bg-state-base-hover',
|
||||
'text-[13px] font-medium text-text-secondary',
|
||||
valueMethod === option && 'bg-state-base-hover',
|
||||
)}
|
||||
onClick={() => {
|
||||
if (valueMethod === option)
|
||||
return
|
||||
onValueMethodChange(option)
|
||||
setOpen(false)
|
||||
}}
|
||||
>
|
||||
{capitalize(option)}
|
||||
</div>
|
||||
))
|
||||
}
|
||||
</div>
|
||||
</PortalToFollowElemContent>
|
||||
</PortalToFollowElem>
|
||||
)
|
||||
}
|
||||
|
||||
export default ConditionValueMethod
|
||||
@@ -0,0 +1,92 @@
|
||||
import { useCallback, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
import VariableTag from '@/app/components/workflow/nodes/_base/components/variable-tag'
|
||||
import VarReferenceVars from '@/app/components/workflow/nodes/_base/components/variable/var-reference-vars'
|
||||
import type {
|
||||
Node,
|
||||
NodeOutPutVar,
|
||||
ValueSelector,
|
||||
Var,
|
||||
} from '@/app/components/workflow/types'
|
||||
import { VarType } from '@/app/components/workflow/types'
|
||||
import { Variable02 } from '@/app/components/base/icons/src/vender/solid/development'
|
||||
|
||||
type ConditionVariableSelectorProps = {
|
||||
valueSelector?: ValueSelector
|
||||
varType?: VarType
|
||||
availableNodes?: Node[]
|
||||
nodesOutputVars?: NodeOutPutVar[]
|
||||
onChange: (valueSelector: ValueSelector, varItem: Var) => void
|
||||
}
|
||||
|
||||
const ConditionVariableSelector = ({
|
||||
valueSelector = [],
|
||||
varType = VarType.string,
|
||||
availableNodes = [],
|
||||
nodesOutputVars = [],
|
||||
onChange,
|
||||
}: ConditionVariableSelectorProps) => {
|
||||
const { t } = useTranslation()
|
||||
const [open, setOpen] = useState(false)
|
||||
|
||||
const handleChange = useCallback((valueSelector: ValueSelector, varItem: Var) => {
|
||||
onChange(valueSelector, varItem)
|
||||
setOpen(false)
|
||||
}, [onChange])
|
||||
|
||||
return (
|
||||
<PortalToFollowElem
|
||||
open={open}
|
||||
onOpenChange={setOpen}
|
||||
placement='bottom-start'
|
||||
offset={{
|
||||
mainAxis: 4,
|
||||
crossAxis: 0,
|
||||
}}
|
||||
>
|
||||
<PortalToFollowElemTrigger asChild onClick={() => setOpen(!open)}>
|
||||
<div className="flex h-6 grow cursor-pointer items-center">
|
||||
{
|
||||
!!valueSelector.length && (
|
||||
<VariableTag
|
||||
valueSelector={valueSelector}
|
||||
varType={varType}
|
||||
availableNodes={availableNodes}
|
||||
isShort
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
!valueSelector.length && (
|
||||
<>
|
||||
<div className='system-sm-regular flex grow items-center text-components-input-text-placeholder'>
|
||||
<Variable02 className='mr-1 h-4 w-4' />
|
||||
{t('workflow.nodes.knowledgeRetrieval.metadata.panel.select')}
|
||||
</div>
|
||||
<div className='system-2xs-medium flex h-5 shrink-0 items-center rounded-[5px] border border-divider-deep px-[5px] text-text-tertiary'>
|
||||
{varType}
|
||||
</div>
|
||||
</>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent className='z-[1000]'>
|
||||
<div className='w-[296px] rounded-lg border-[0.5px] border-components-panel-border bg-components-panel-bg-blur shadow-lg'>
|
||||
<VarReferenceVars
|
||||
vars={nodesOutputVars}
|
||||
isSupportFileVar
|
||||
onChange={handleChange}
|
||||
/>
|
||||
</div>
|
||||
</PortalToFollowElemContent>
|
||||
</PortalToFollowElem>
|
||||
)
|
||||
}
|
||||
|
||||
export default ConditionVariableSelector
|
||||
@@ -0,0 +1,75 @@
|
||||
import { RiLoopLeftLine } from '@remixicon/react'
|
||||
import ConditionItem from './condition-item'
|
||||
import cn from '@/utils/classnames'
|
||||
import type { MetadataShape } from '@/app/components/workflow/nodes/knowledge-retrieval/types'
|
||||
import { LogicalOperator } from '@/app/components/workflow/nodes/knowledge-retrieval/types'
|
||||
|
||||
type ConditionListProps = {
|
||||
disabled?: boolean
|
||||
} & Omit<MetadataShape, 'handleAddCondition'>
|
||||
|
||||
const ConditionList = ({
|
||||
disabled,
|
||||
metadataList = [],
|
||||
metadataFilteringConditions = {
|
||||
conditions: [],
|
||||
logical_operator: LogicalOperator.and,
|
||||
},
|
||||
handleRemoveCondition,
|
||||
handleToggleConditionLogicalOperator,
|
||||
handleUpdateCondition,
|
||||
availableStringVars,
|
||||
availableStringNodesWithParent,
|
||||
availableNumberVars,
|
||||
availableNumberNodesWithParent,
|
||||
isCommonVariable,
|
||||
availableCommonNumberVars,
|
||||
availableCommonStringVars,
|
||||
}: ConditionListProps) => {
|
||||
const { conditions, logical_operator } = metadataFilteringConditions
|
||||
|
||||
return (
|
||||
<div className={cn('relative')}>
|
||||
{
|
||||
conditions.length > 1 && (
|
||||
<div className={cn(
|
||||
'absolute bottom-0 left-0 top-0 w-[44px]',
|
||||
)}>
|
||||
<div className='absolute bottom-4 right-1 top-4 w-2.5 rounded-l-[8px] border border-r-0 border-divider-deep'></div>
|
||||
<div className='absolute right-0 top-1/2 h-[29px] w-4 -translate-y-1/2 bg-components-panel-bg'></div>
|
||||
<div
|
||||
className='absolute right-1 top-1/2 flex h-[21px] -translate-y-1/2 cursor-pointer select-none items-center rounded-md border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg px-1 text-[10px] font-semibold text-text-accent-secondary shadow-xs'
|
||||
onClick={() => handleToggleConditionLogicalOperator()}
|
||||
>
|
||||
{logical_operator.toUpperCase()}
|
||||
<RiLoopLeftLine className='ml-0.5 h-3 w-3' />
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
<div className={cn(conditions.length > 1 && 'pl-[44px]')}>
|
||||
{
|
||||
conditions.map(condition => (
|
||||
<ConditionItem
|
||||
key={`${condition.id}`}
|
||||
disabled={disabled}
|
||||
condition={condition}
|
||||
onUpdateCondition={handleUpdateCondition}
|
||||
onRemoveCondition={handleRemoveCondition}
|
||||
metadataList={metadataList}
|
||||
availableStringVars={availableStringVars}
|
||||
availableStringNodesWithParent={availableStringNodesWithParent}
|
||||
availableNumberVars={availableNumberVars}
|
||||
availableNumberNodesWithParent={availableNumberNodesWithParent}
|
||||
isCommonVariable={isCommonVariable}
|
||||
availableCommonStringVars={availableCommonStringVars}
|
||||
availableCommonNumberVars={availableCommonNumberVars}
|
||||
/>
|
||||
))
|
||||
}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default ConditionList
|
||||
@@ -0,0 +1,68 @@
|
||||
import {
|
||||
ComparisonOperator,
|
||||
MetadataFilteringVariableType,
|
||||
} from '@/app/components/workflow/nodes/knowledge-retrieval/types'
|
||||
|
||||
export const isEmptyRelatedOperator = (operator: ComparisonOperator) => {
|
||||
return [ComparisonOperator.empty, ComparisonOperator.notEmpty, ComparisonOperator.isNull, ComparisonOperator.isNotNull, ComparisonOperator.exists, ComparisonOperator.notExists].includes(operator)
|
||||
}
|
||||
|
||||
const notTranslateKey = [
|
||||
ComparisonOperator.equal, ComparisonOperator.notEqual,
|
||||
ComparisonOperator.largerThan, ComparisonOperator.largerThanOrEqual,
|
||||
ComparisonOperator.lessThan, ComparisonOperator.lessThanOrEqual,
|
||||
]
|
||||
|
||||
export const isComparisonOperatorNeedTranslate = (operator?: ComparisonOperator) => {
|
||||
if (!operator)
|
||||
return false
|
||||
return !notTranslateKey.includes(operator)
|
||||
}
|
||||
|
||||
export const getOperators = (type?: MetadataFilteringVariableType) => {
|
||||
switch (type) {
|
||||
case MetadataFilteringVariableType.string:
|
||||
case MetadataFilteringVariableType.select:
|
||||
return [
|
||||
ComparisonOperator.is,
|
||||
ComparisonOperator.isNot,
|
||||
ComparisonOperator.contains,
|
||||
ComparisonOperator.notContains,
|
||||
ComparisonOperator.startWith,
|
||||
ComparisonOperator.endWith,
|
||||
ComparisonOperator.empty,
|
||||
ComparisonOperator.notEmpty,
|
||||
ComparisonOperator.in,
|
||||
ComparisonOperator.notIn,
|
||||
]
|
||||
case MetadataFilteringVariableType.number:
|
||||
return [
|
||||
ComparisonOperator.equal,
|
||||
ComparisonOperator.notEqual,
|
||||
ComparisonOperator.largerThan,
|
||||
ComparisonOperator.lessThan,
|
||||
ComparisonOperator.largerThanOrEqual,
|
||||
ComparisonOperator.lessThanOrEqual,
|
||||
ComparisonOperator.empty,
|
||||
ComparisonOperator.notEmpty,
|
||||
]
|
||||
default:
|
||||
return [
|
||||
ComparisonOperator.is,
|
||||
ComparisonOperator.before,
|
||||
ComparisonOperator.after,
|
||||
ComparisonOperator.empty,
|
||||
ComparisonOperator.notEmpty,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
export const comparisonOperatorNotRequireValue = (operator?: ComparisonOperator) => {
|
||||
if (!operator)
|
||||
return false
|
||||
|
||||
return [ComparisonOperator.empty, ComparisonOperator.notEmpty, ComparisonOperator.isNull, ComparisonOperator.isNotNull, ComparisonOperator.exists, ComparisonOperator.notExists].includes(operator)
|
||||
}
|
||||
|
||||
export const VARIABLE_REGEX = /\{\{(#[a-zA-Z0-9_-]{1,50}(\.[a-zA-Z_]\w{0,29}){1,10}#)\}\}/gi
|
||||
export const COMMON_VARIABLE_REGEX = /\{\{([a-zA-Z0-9_-]{1,50})\}\}/gi
|
||||
@@ -0,0 +1,104 @@
|
||||
import {
|
||||
useCallback,
|
||||
useState,
|
||||
} from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import MetadataTrigger from '../metadata-trigger'
|
||||
import MetadataFilterSelector from './metadata-filter-selector'
|
||||
import Collapse from '@/app/components/workflow/nodes/_base/components/collapse'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import type { MetadataShape } from '@/app/components/workflow/nodes/knowledge-retrieval/types'
|
||||
import { MetadataFilteringModeEnum } from '@/app/components/workflow/nodes/knowledge-retrieval/types'
|
||||
import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal'
|
||||
import { noop } from 'lodash-es'
|
||||
|
||||
type MetadataFilterProps = {
|
||||
metadataFilterMode?: MetadataFilteringModeEnum
|
||||
handleMetadataFilterModeChange: (mode: MetadataFilteringModeEnum) => void
|
||||
} & MetadataShape
|
||||
const MetadataFilter = ({
|
||||
metadataFilterMode = MetadataFilteringModeEnum.disabled,
|
||||
handleMetadataFilterModeChange,
|
||||
metadataModelConfig,
|
||||
handleMetadataModelChange,
|
||||
handleMetadataCompletionParamsChange,
|
||||
...restProps
|
||||
}: MetadataFilterProps) => {
|
||||
const { t } = useTranslation()
|
||||
const [collapsed, setCollapsed] = useState(true)
|
||||
|
||||
const handleMetadataFilterModeChangeWrapped = useCallback((mode: MetadataFilteringModeEnum) => {
|
||||
if (mode === MetadataFilteringModeEnum.automatic)
|
||||
setCollapsed(false)
|
||||
|
||||
handleMetadataFilterModeChange(mode)
|
||||
}, [handleMetadataFilterModeChange])
|
||||
|
||||
return (
|
||||
<Collapse
|
||||
disabled={metadataFilterMode === MetadataFilteringModeEnum.disabled || metadataFilterMode === MetadataFilteringModeEnum.manual}
|
||||
collapsed={collapsed}
|
||||
onCollapse={setCollapsed}
|
||||
hideCollapseIcon
|
||||
trigger={collapseIcon => (
|
||||
<div className='flex grow items-center justify-between pr-4'>
|
||||
<div className='flex items-center'>
|
||||
<div className='system-sm-semibold-uppercase mr-0.5 text-text-secondary'>
|
||||
{t('workflow.nodes.knowledgeRetrieval.metadata.title')}
|
||||
</div>
|
||||
<Tooltip
|
||||
popupContent={(
|
||||
<div className='w-[200px]'>
|
||||
{t('workflow.nodes.knowledgeRetrieval.metadata.tip')}
|
||||
</div>
|
||||
)}
|
||||
/>
|
||||
{collapseIcon}
|
||||
</div>
|
||||
<div className='flex items-center'>
|
||||
<MetadataFilterSelector
|
||||
value={metadataFilterMode}
|
||||
onSelect={handleMetadataFilterModeChangeWrapped}
|
||||
/>
|
||||
{
|
||||
metadataFilterMode === MetadataFilteringModeEnum.manual && (
|
||||
<div className='ml-1'>
|
||||
<MetadataTrigger {...restProps} />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
>
|
||||
<>
|
||||
{
|
||||
metadataFilterMode === MetadataFilteringModeEnum.automatic && (
|
||||
<>
|
||||
<div className='body-xs-regular px-4 text-text-tertiary'>
|
||||
{t('workflow.nodes.knowledgeRetrieval.metadata.options.automatic.desc')}
|
||||
</div>
|
||||
<div className='mt-1 px-4'>
|
||||
<ModelParameterModal
|
||||
portalToFollowElemContentClassName='z-[50]'
|
||||
popupClassName='!w-[387px]'
|
||||
isInWorkflow
|
||||
isAdvancedMode={true}
|
||||
provider={metadataModelConfig?.provider || ''}
|
||||
completionParams={metadataModelConfig?.completion_params || { temperature: 0.7 }}
|
||||
modelId={metadataModelConfig?.name || ''}
|
||||
setModel={handleMetadataModelChange || noop}
|
||||
onCompletionParamsChange={handleMetadataCompletionParamsChange || noop}
|
||||
hideDebugWithMultipleModel
|
||||
debugWithMultipleModel={false}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)
|
||||
}
|
||||
</>
|
||||
</Collapse>
|
||||
)
|
||||
}
|
||||
|
||||
export default MetadataFilter
|
||||
@@ -0,0 +1,106 @@
|
||||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import {
|
||||
RiArrowDownSLine,
|
||||
RiCheckLine,
|
||||
} from '@remixicon/react'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
import Button from '@/app/components/base/button'
|
||||
import { MetadataFilteringModeEnum } from '@/app/components/workflow/nodes/knowledge-retrieval/types'
|
||||
|
||||
type MetadataFilterSelectorProps = {
|
||||
value?: MetadataFilteringModeEnum
|
||||
onSelect: (value: MetadataFilteringModeEnum) => void
|
||||
}
|
||||
const MetadataFilterSelector = ({
|
||||
value = MetadataFilteringModeEnum.disabled,
|
||||
onSelect,
|
||||
}: MetadataFilterSelectorProps) => {
|
||||
const { t } = useTranslation()
|
||||
const [open, setOpen] = useState(false)
|
||||
const options = [
|
||||
{
|
||||
key: MetadataFilteringModeEnum.disabled,
|
||||
value: t('workflow.nodes.knowledgeRetrieval.metadata.options.disabled.title'),
|
||||
desc: t('workflow.nodes.knowledgeRetrieval.metadata.options.disabled.subTitle'),
|
||||
},
|
||||
{
|
||||
key: MetadataFilteringModeEnum.automatic,
|
||||
value: t('workflow.nodes.knowledgeRetrieval.metadata.options.automatic.title'),
|
||||
desc: t('workflow.nodes.knowledgeRetrieval.metadata.options.automatic.subTitle'),
|
||||
},
|
||||
{
|
||||
key: MetadataFilteringModeEnum.manual,
|
||||
value: t('workflow.nodes.knowledgeRetrieval.metadata.options.manual.title'),
|
||||
desc: t('workflow.nodes.knowledgeRetrieval.metadata.options.manual.subTitle'),
|
||||
},
|
||||
]
|
||||
|
||||
const selectedOption = options.find(option => option.key === value)!
|
||||
|
||||
return (
|
||||
<PortalToFollowElem
|
||||
placement='bottom-end'
|
||||
offset={{
|
||||
mainAxis: 4,
|
||||
crossAxis: 0,
|
||||
}}
|
||||
open={open}
|
||||
onOpenChange={setOpen}
|
||||
>
|
||||
<PortalToFollowElemTrigger
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
setOpen(!open)
|
||||
}}
|
||||
asChild
|
||||
>
|
||||
<Button
|
||||
variant='secondary'
|
||||
size='small'
|
||||
>
|
||||
{selectedOption.value}
|
||||
<RiArrowDownSLine className='h-3.5 w-3.5' />
|
||||
</Button>
|
||||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent className='z-10'>
|
||||
<div className='w-[280px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-1 shadow-lg'>
|
||||
{
|
||||
options.map(option => (
|
||||
<div
|
||||
key={option.key}
|
||||
className='flex cursor-pointer rounded-lg p-2 pr-3 hover:bg-state-base-hover'
|
||||
onClick={() => {
|
||||
onSelect(option.key)
|
||||
setOpen(false)
|
||||
}}
|
||||
>
|
||||
<div className='w-4 shrink-0'>
|
||||
{
|
||||
option.key === value && (
|
||||
<RiCheckLine className='h-4 w-4 text-text-accent' />
|
||||
)
|
||||
}
|
||||
</div>
|
||||
<div className='grow'>
|
||||
<div className='system-sm-semibold text-text-secondary'>
|
||||
{option.value}
|
||||
</div>
|
||||
<div className='system-xs-regular text-text-tertiary'>
|
||||
{option.desc}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))
|
||||
}
|
||||
</div>
|
||||
</PortalToFollowElemContent>
|
||||
</PortalToFollowElem>
|
||||
)
|
||||
}
|
||||
|
||||
export default MetadataFilterSelector
|
||||
@@ -0,0 +1,39 @@
|
||||
import { memo } from 'react'
|
||||
import {
|
||||
RiHashtag,
|
||||
RiTextSnippet,
|
||||
RiTimeLine,
|
||||
} from '@remixicon/react'
|
||||
import { MetadataFilteringVariableType } from '@/app/components/workflow/nodes/knowledge-retrieval/types'
|
||||
import cn from '@/utils/classnames'
|
||||
|
||||
type MetadataIconProps = {
|
||||
type?: MetadataFilteringVariableType
|
||||
className?: string
|
||||
}
|
||||
const MetadataIcon = ({
|
||||
type,
|
||||
className,
|
||||
}: MetadataIconProps) => {
|
||||
return (
|
||||
<>
|
||||
{
|
||||
(type === MetadataFilteringVariableType.string || type === MetadataFilteringVariableType.select) && (
|
||||
<RiTextSnippet className={cn('h-3.5 w-3.5', className)} />
|
||||
)
|
||||
}
|
||||
{
|
||||
type === MetadataFilteringVariableType.number && (
|
||||
<RiHashtag className={cn('h-3.5 w-3.5', className)} />
|
||||
)
|
||||
}
|
||||
{
|
||||
type === MetadataFilteringVariableType.time && (
|
||||
<RiTimeLine className={cn('h-3.5 w-3.5', className)} />
|
||||
)
|
||||
}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(MetadataIcon)
|
||||
@@ -0,0 +1,51 @@
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { RiCloseLine } from '@remixicon/react'
|
||||
import AddCondition from './add-condition'
|
||||
import ConditionList from './condition-list'
|
||||
import type { MetadataShape } from '@/app/components/workflow/nodes/knowledge-retrieval/types'
|
||||
|
||||
type MetadataPanelProps = {
|
||||
onCancel: () => void
|
||||
} & MetadataShape
|
||||
const MetadataPanel = ({
|
||||
metadataFilteringConditions,
|
||||
metadataList,
|
||||
onCancel,
|
||||
handleAddCondition,
|
||||
...restProps
|
||||
}: MetadataPanelProps) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<div className='w-[420px] rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-2xl'>
|
||||
<div className='relative px-3 pt-3.5'>
|
||||
<div className='system-xl-semibold text-text-primary'>
|
||||
{t('workflow.nodes.knowledgeRetrieval.metadata.panel.title')}
|
||||
</div>
|
||||
<div
|
||||
className='absolute bottom-0 right-2.5 flex h-8 w-8 cursor-pointer items-center justify-center'
|
||||
onClick={onCancel}
|
||||
>
|
||||
<RiCloseLine className='h-4 w-4 text-text-tertiary' />
|
||||
</div>
|
||||
</div>
|
||||
<div className='px-1 py-2'>
|
||||
<div className='px-3 py-1'>
|
||||
<div className='pb-2'>
|
||||
<ConditionList
|
||||
metadataList={metadataList}
|
||||
metadataFilteringConditions={metadataFilteringConditions}
|
||||
{...restProps}
|
||||
/>
|
||||
</div>
|
||||
<AddCondition
|
||||
metadataList={metadataList}
|
||||
handleAddCondition={handleAddCondition}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default MetadataPanel
|
||||
@@ -0,0 +1,68 @@
|
||||
import {
|
||||
useEffect,
|
||||
useState,
|
||||
} from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { RiFilter3Line } from '@remixicon/react'
|
||||
import MetadataPanel from './metadata-panel'
|
||||
import Button from '@/app/components/base/button'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
import type { MetadataShape } from '@/app/components/workflow/nodes/knowledge-retrieval/types'
|
||||
|
||||
const MetadataTrigger = ({
|
||||
metadataFilteringConditions,
|
||||
metadataList = [],
|
||||
handleRemoveCondition,
|
||||
selectedDatasetsLoaded,
|
||||
...restProps
|
||||
}: MetadataShape) => {
|
||||
const { t } = useTranslation()
|
||||
const [open, setOpen] = useState(false)
|
||||
const conditions = metadataFilteringConditions?.conditions || []
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedDatasetsLoaded) {
|
||||
conditions.forEach((condition) => {
|
||||
if (!metadataList.find(metadata => metadata.name === condition.name))
|
||||
handleRemoveCondition(condition.id)
|
||||
})
|
||||
}
|
||||
}, [metadataList, handleRemoveCondition, selectedDatasetsLoaded])
|
||||
|
||||
return (
|
||||
<PortalToFollowElem
|
||||
placement='left'
|
||||
offset={4}
|
||||
open={open}
|
||||
onOpenChange={setOpen}
|
||||
>
|
||||
<PortalToFollowElemTrigger onClick={() => setOpen(!open)}>
|
||||
<Button
|
||||
variant='secondary-accent'
|
||||
size='small'
|
||||
>
|
||||
<RiFilter3Line className='mr-1 h-3.5 w-3.5' />
|
||||
{t('workflow.nodes.knowledgeRetrieval.metadata.panel.conditions')}
|
||||
<div className='system-2xs-medium-uppercase ml-1 flex items-center rounded-[5px] border border-divider-deep px-1 text-text-tertiary'>
|
||||
{metadataFilteringConditions?.conditions.length || 0}
|
||||
</div>
|
||||
</Button>
|
||||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent className='z-10'>
|
||||
<MetadataPanel
|
||||
metadataFilteringConditions={metadataFilteringConditions}
|
||||
onCancel={() => setOpen(false)}
|
||||
metadataList={metadataList}
|
||||
handleRemoveCondition={handleRemoveCondition}
|
||||
{...restProps}
|
||||
/>
|
||||
</PortalToFollowElemContent>
|
||||
</PortalToFollowElem>
|
||||
)
|
||||
}
|
||||
|
||||
export default MetadataTrigger
|
||||
@@ -0,0 +1,156 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React, { useCallback, useMemo } from 'react'
|
||||
import { RiEqualizer2Line } from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import type { MultipleRetrievalConfig, SingleRetrievalConfig } from '../types'
|
||||
import type { ModelConfig } from '../../../types'
|
||||
import cn from '@/utils/classnames'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
import ConfigRetrievalContent from '@/app/components/app/configuration/dataset-config/params-config/config-content'
|
||||
import { RETRIEVE_TYPE } from '@/types/app'
|
||||
import { DATASET_DEFAULT } from '@/config'
|
||||
import Button from '@/app/components/base/button'
|
||||
import type { DatasetConfigs } from '@/models/debug'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
|
||||
type Props = {
|
||||
payload: {
|
||||
retrieval_mode: RETRIEVE_TYPE
|
||||
multiple_retrieval_config?: MultipleRetrievalConfig
|
||||
single_retrieval_config?: SingleRetrievalConfig
|
||||
}
|
||||
onRetrievalModeChange: (mode: RETRIEVE_TYPE) => void
|
||||
onMultipleRetrievalConfigChange: (config: MultipleRetrievalConfig) => void
|
||||
singleRetrievalModelConfig?: ModelConfig
|
||||
onSingleRetrievalModelChange?: (config: ModelConfig) => void
|
||||
onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void
|
||||
readonly?: boolean
|
||||
rerankModalOpen: boolean
|
||||
onRerankModelOpenChange: (open: boolean) => void
|
||||
selectedDatasets: DataSet[]
|
||||
}
|
||||
|
||||
const RetrievalConfig: FC<Props> = ({
|
||||
payload,
|
||||
onRetrievalModeChange,
|
||||
onMultipleRetrievalConfigChange,
|
||||
singleRetrievalModelConfig,
|
||||
onSingleRetrievalModelChange,
|
||||
onSingleRetrievalModelParamsChange,
|
||||
readonly,
|
||||
rerankModalOpen,
|
||||
onRerankModelOpenChange,
|
||||
selectedDatasets,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { retrieval_mode, multiple_retrieval_config } = payload
|
||||
|
||||
const handleOpen = useCallback((newOpen: boolean) => {
|
||||
onRerankModelOpenChange(newOpen)
|
||||
}, [onRerankModelOpenChange])
|
||||
|
||||
const datasetConfigs = useMemo(() => {
|
||||
const {
|
||||
reranking_model,
|
||||
top_k,
|
||||
score_threshold,
|
||||
reranking_mode,
|
||||
weights,
|
||||
reranking_enable,
|
||||
} = multiple_retrieval_config || {}
|
||||
|
||||
return {
|
||||
retrieval_model: retrieval_mode,
|
||||
reranking_model: (reranking_model?.provider && reranking_model?.model)
|
||||
? {
|
||||
reranking_provider_name: reranking_model?.provider,
|
||||
reranking_model_name: reranking_model?.model,
|
||||
}
|
||||
: {
|
||||
reranking_provider_name: '',
|
||||
reranking_model_name: '',
|
||||
},
|
||||
top_k: top_k || DATASET_DEFAULT.top_k,
|
||||
score_threshold_enabled: !(score_threshold === undefined || score_threshold === null),
|
||||
score_threshold,
|
||||
datasets: {
|
||||
datasets: [],
|
||||
},
|
||||
reranking_mode,
|
||||
weights,
|
||||
reranking_enable,
|
||||
}
|
||||
}, [retrieval_mode, multiple_retrieval_config])
|
||||
|
||||
const handleChange = useCallback((configs: DatasetConfigs, isRetrievalModeChange?: boolean) => {
|
||||
// Legacy code, for compatibility, have to keep it
|
||||
if (isRetrievalModeChange) {
|
||||
onRetrievalModeChange(configs.retrieval_model)
|
||||
return
|
||||
}
|
||||
onMultipleRetrievalConfigChange({
|
||||
top_k: configs.top_k,
|
||||
score_threshold: configs.score_threshold_enabled ? (configs.score_threshold ?? DATASET_DEFAULT.score_threshold) : null,
|
||||
reranking_model: retrieval_mode === RETRIEVE_TYPE.oneWay
|
||||
? undefined
|
||||
// eslint-disable-next-line sonarjs/no-nested-conditional
|
||||
: (!configs.reranking_model?.reranking_provider_name
|
||||
? undefined
|
||||
: {
|
||||
provider: configs.reranking_model?.reranking_provider_name,
|
||||
model: configs.reranking_model?.reranking_model_name,
|
||||
}),
|
||||
reranking_mode: configs.reranking_mode,
|
||||
weights: configs.weights,
|
||||
reranking_enable: configs.reranking_enable,
|
||||
})
|
||||
}, [onMultipleRetrievalConfigChange, retrieval_mode, onRetrievalModeChange])
|
||||
|
||||
return (
|
||||
<PortalToFollowElem
|
||||
open={rerankModalOpen}
|
||||
onOpenChange={handleOpen}
|
||||
placement='bottom-end'
|
||||
offset={{
|
||||
crossAxis: -2,
|
||||
}}
|
||||
>
|
||||
<PortalToFollowElemTrigger
|
||||
onClick={() => {
|
||||
if (readonly)
|
||||
return
|
||||
handleOpen(!rerankModalOpen)
|
||||
}}
|
||||
>
|
||||
<Button
|
||||
variant='ghost'
|
||||
size='small'
|
||||
disabled={readonly}
|
||||
className={cn(rerankModalOpen && 'bg-components-button-ghost-bg-hover')}
|
||||
>
|
||||
<RiEqualizer2Line className='mr-1 h-3.5 w-3.5' />
|
||||
{t('dataset.retrievalSettings')}
|
||||
</Button>
|
||||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent style={{ zIndex: 1001 }}>
|
||||
<div className='w-[404px] rounded-2xl border border-components-panel-border bg-components-panel-bg px-4 pb-4 pt-3 shadow-xl'>
|
||||
<ConfigRetrievalContent
|
||||
datasetConfigs={datasetConfigs}
|
||||
onChange={handleChange}
|
||||
selectedDatasets={selectedDatasets}
|
||||
isInWorkflow
|
||||
singleRetrievalModelConfig={singleRetrievalModelConfig}
|
||||
onSingleRetrievalModelChange={onSingleRetrievalModelChange}
|
||||
onSingleRetrievalModelParamsChange={onSingleRetrievalModelParamsChange}
|
||||
/>
|
||||
</div>
|
||||
</PortalToFollowElemContent>
|
||||
</PortalToFollowElem>
|
||||
)
|
||||
}
|
||||
export default React.memo(RetrievalConfig)
|
||||
@@ -0,0 +1,52 @@
|
||||
import type { NodeDefault } from '../../types'
|
||||
import type { KnowledgeRetrievalNodeType } from './types'
|
||||
import { checkoutRerankModelConfiguredInRetrievalSettings } from './utils'
|
||||
import { DATASET_DEFAULT } from '@/config'
|
||||
import { RETRIEVE_TYPE } from '@/types/app'
|
||||
import { genNodeMetaData } from '@/app/components/workflow/utils'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
const i18nPrefix = 'workflow'
|
||||
|
||||
const metaData = genNodeMetaData({
|
||||
sort: 2,
|
||||
type: BlockEnum.KnowledgeRetrieval,
|
||||
})
|
||||
const nodeDefault: NodeDefault<KnowledgeRetrievalNodeType> = {
|
||||
metaData,
|
||||
defaultValue: {
|
||||
query_variable_selector: [],
|
||||
dataset_ids: [],
|
||||
retrieval_mode: RETRIEVE_TYPE.multiWay,
|
||||
multiple_retrieval_config: {
|
||||
top_k: DATASET_DEFAULT.top_k,
|
||||
score_threshold: undefined,
|
||||
reranking_enable: false,
|
||||
},
|
||||
},
|
||||
checkValid(payload: KnowledgeRetrievalNodeType, t: any) {
|
||||
let errorMessages = ''
|
||||
if (!errorMessages && (!payload.query_variable_selector || payload.query_variable_selector.length === 0))
|
||||
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.knowledgeRetrieval.queryVariable`) })
|
||||
|
||||
if (!errorMessages && (!payload.dataset_ids || payload.dataset_ids.length === 0))
|
||||
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.knowledgeRetrieval.knowledge`) })
|
||||
|
||||
if (!errorMessages && payload.retrieval_mode === RETRIEVE_TYPE.oneWay && !payload.single_retrieval_config?.model?.provider)
|
||||
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t('common.modelProvider.systemReasoningModel.key') })
|
||||
|
||||
const { _datasets, multiple_retrieval_config, retrieval_mode } = payload
|
||||
if (retrieval_mode === RETRIEVE_TYPE.multiWay) {
|
||||
const checked = checkoutRerankModelConfiguredInRetrievalSettings(_datasets || [], multiple_retrieval_config)
|
||||
|
||||
if (!errorMessages && !checked)
|
||||
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) })
|
||||
}
|
||||
|
||||
return {
|
||||
isValid: !errorMessages,
|
||||
errorMessage: errorMessages,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
export default nodeDefault
|
||||
@@ -0,0 +1,14 @@
|
||||
import { useMemo } from 'react'
|
||||
import { getSelectedDatasetsMode } from './utils'
|
||||
import type {
|
||||
DataSet,
|
||||
SelectedDatasetsMode,
|
||||
} from '@/models/datasets'
|
||||
|
||||
export const useSelectedDatasetsMode = (datasets: DataSet[]) => {
|
||||
const selectedDatasetsMode: SelectedDatasetsMode = useMemo(() => {
|
||||
return getSelectedDatasetsMode(datasets)
|
||||
}, [datasets])
|
||||
|
||||
return selectedDatasetsMode
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
import { type FC, useEffect, useState } from 'react'
|
||||
import React from 'react'
|
||||
import type { KnowledgeRetrievalNodeType } from './types'
|
||||
import type { NodeProps } from '@/app/components/workflow/types'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import { useDatasetsDetailStore } from '../../datasets-detail-store/store'
|
||||
import AppIcon from '@/app/components/base/app-icon'
|
||||
|
||||
const Node: FC<NodeProps<KnowledgeRetrievalNodeType>> = ({
|
||||
data,
|
||||
}) => {
|
||||
const [selectedDatasets, setSelectedDatasets] = useState<DataSet[]>([])
|
||||
const datasetsDetail = useDatasetsDetailStore(s => s.datasetsDetail)
|
||||
|
||||
useEffect(() => {
|
||||
if (data.dataset_ids?.length > 0) {
|
||||
const dataSetsWithDetail = data.dataset_ids.reduce<DataSet[]>((acc, id) => {
|
||||
if (datasetsDetail[id])
|
||||
acc.push(datasetsDetail[id])
|
||||
return acc
|
||||
}, [])
|
||||
setSelectedDatasets(dataSetsWithDetail)
|
||||
}
|
||||
else {
|
||||
setSelectedDatasets([])
|
||||
}
|
||||
}, [data.dataset_ids, datasetsDetail])
|
||||
|
||||
if (!selectedDatasets.length)
|
||||
return null
|
||||
|
||||
return (
|
||||
<div className='mb-1 px-3 py-1'>
|
||||
<div className='space-y-0.5'>
|
||||
{selectedDatasets.map(({ id, name, icon_info }) => (
|
||||
<div key={id} className='flex h-[26px] items-center gap-x-1 rounded-md bg-workflow-block-parma-bg px-1'>
|
||||
<AppIcon
|
||||
size='xs'
|
||||
iconType={icon_info.icon_type}
|
||||
icon={icon_info.icon}
|
||||
background={icon_info.icon_type === 'image' ? undefined : icon_info.icon_background}
|
||||
imageUrl={icon_info.icon_type === 'image' ? icon_info.icon_url : undefined}
|
||||
className='shrink-0'
|
||||
/>
|
||||
<div className='system-xs-regular w-0 grow truncate text-text-secondary'>
|
||||
{name}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default React.memo(Node)
|
||||
@@ -0,0 +1,183 @@
|
||||
import type { FC } from 'react'
|
||||
import {
|
||||
memo,
|
||||
useMemo,
|
||||
} from 'react'
|
||||
import { intersectionBy } from 'lodash-es'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import VarReferencePicker from '../_base/components/variable/var-reference-picker'
|
||||
import useConfig from './use-config'
|
||||
import RetrievalConfig from './components/retrieval-config'
|
||||
import AddKnowledge from './components/add-dataset'
|
||||
import DatasetList from './components/dataset-list'
|
||||
import MetadataFilter from './components/metadata/metadata-filter'
|
||||
import type { KnowledgeRetrievalNodeType } from './types'
|
||||
import Field from '@/app/components/workflow/nodes/_base/components/field'
|
||||
import Split from '@/app/components/workflow/nodes/_base/components/split'
|
||||
import OutputVars, { VarItem } from '@/app/components/workflow/nodes/_base/components/output-vars'
|
||||
import type { NodePanelProps } from '@/app/components/workflow/types'
|
||||
|
||||
const i18nPrefix = 'workflow.nodes.knowledgeRetrieval'
|
||||
|
||||
const Panel: FC<NodePanelProps<KnowledgeRetrievalNodeType>> = ({
|
||||
id,
|
||||
data,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const {
|
||||
readOnly,
|
||||
inputs,
|
||||
handleQueryVarChange,
|
||||
filterVar,
|
||||
handleModelChanged,
|
||||
handleCompletionParamsChange,
|
||||
handleRetrievalModeChange,
|
||||
handleMultipleRetrievalConfigChange,
|
||||
selectedDatasets,
|
||||
selectedDatasetsLoaded,
|
||||
handleOnDatasetsChange,
|
||||
rerankModelOpen,
|
||||
setRerankModelOpen,
|
||||
handleAddCondition,
|
||||
handleMetadataFilterModeChange,
|
||||
handleRemoveCondition,
|
||||
handleToggleConditionLogicalOperator,
|
||||
handleUpdateCondition,
|
||||
handleMetadataModelChange,
|
||||
handleMetadataCompletionParamsChange,
|
||||
availableStringVars,
|
||||
availableStringNodesWithParent,
|
||||
availableNumberVars,
|
||||
availableNumberNodesWithParent,
|
||||
} = useConfig(id, data)
|
||||
|
||||
const metadataList = useMemo(() => {
|
||||
return intersectionBy(...selectedDatasets.filter((dataset) => {
|
||||
return !!dataset.doc_metadata
|
||||
}).map((dataset) => {
|
||||
return dataset.doc_metadata!
|
||||
}), 'name')
|
||||
}, [selectedDatasets])
|
||||
|
||||
return (
|
||||
<div className='pt-2'>
|
||||
<div className='space-y-4 px-4 pb-2'>
|
||||
<Field
|
||||
title={t(`${i18nPrefix}.queryVariable`)}
|
||||
required
|
||||
>
|
||||
<VarReferencePicker
|
||||
nodeId={id}
|
||||
readonly={readOnly}
|
||||
isShowNodeName
|
||||
value={inputs.query_variable_selector}
|
||||
onChange={handleQueryVarChange}
|
||||
filterVar={filterVar}
|
||||
/>
|
||||
</Field>
|
||||
|
||||
<Field
|
||||
title={t(`${i18nPrefix}.knowledge`)}
|
||||
required
|
||||
operations={
|
||||
<div className='flex items-center space-x-1'>
|
||||
<RetrievalConfig
|
||||
payload={{
|
||||
retrieval_mode: inputs.retrieval_mode,
|
||||
multiple_retrieval_config: inputs.multiple_retrieval_config,
|
||||
single_retrieval_config: inputs.single_retrieval_config,
|
||||
}}
|
||||
onRetrievalModeChange={handleRetrievalModeChange}
|
||||
onMultipleRetrievalConfigChange={handleMultipleRetrievalConfigChange}
|
||||
singleRetrievalModelConfig={inputs.single_retrieval_config?.model}
|
||||
onSingleRetrievalModelChange={handleModelChanged as any}
|
||||
onSingleRetrievalModelParamsChange={handleCompletionParamsChange}
|
||||
readonly={readOnly || !selectedDatasets.length}
|
||||
rerankModalOpen={rerankModelOpen}
|
||||
onRerankModelOpenChange={setRerankModelOpen}
|
||||
selectedDatasets={selectedDatasets}
|
||||
/>
|
||||
{!readOnly && (<div className='h-3 w-px bg-divider-regular'></div>)}
|
||||
{!readOnly && (
|
||||
<AddKnowledge
|
||||
selectedIds={inputs.dataset_ids}
|
||||
onChange={handleOnDatasetsChange}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
}
|
||||
>
|
||||
<DatasetList
|
||||
list={selectedDatasets}
|
||||
onChange={handleOnDatasetsChange}
|
||||
readonly={readOnly}
|
||||
/>
|
||||
</Field>
|
||||
</div>
|
||||
<div className='mb-2 py-2'>
|
||||
<MetadataFilter
|
||||
metadataList={metadataList}
|
||||
selectedDatasetsLoaded={selectedDatasetsLoaded}
|
||||
metadataFilterMode={inputs.metadata_filtering_mode}
|
||||
metadataFilteringConditions={inputs.metadata_filtering_conditions}
|
||||
handleAddCondition={handleAddCondition}
|
||||
handleMetadataFilterModeChange={handleMetadataFilterModeChange}
|
||||
handleRemoveCondition={handleRemoveCondition}
|
||||
handleToggleConditionLogicalOperator={handleToggleConditionLogicalOperator}
|
||||
handleUpdateCondition={handleUpdateCondition}
|
||||
metadataModelConfig={inputs.metadata_model_config}
|
||||
handleMetadataModelChange={handleMetadataModelChange}
|
||||
handleMetadataCompletionParamsChange={handleMetadataCompletionParamsChange}
|
||||
availableStringVars={availableStringVars}
|
||||
availableStringNodesWithParent={availableStringNodesWithParent}
|
||||
availableNumberVars={availableNumberVars}
|
||||
availableNumberNodesWithParent={availableNumberNodesWithParent}
|
||||
/>
|
||||
</div>
|
||||
<Split />
|
||||
<div>
|
||||
<OutputVars>
|
||||
<>
|
||||
<VarItem
|
||||
name='result'
|
||||
type='Array[Object]'
|
||||
description={t(`${i18nPrefix}.outputVars.output`)}
|
||||
subItems={[
|
||||
{
|
||||
name: 'content',
|
||||
type: 'string',
|
||||
description: t(`${i18nPrefix}.outputVars.content`),
|
||||
},
|
||||
// url, title, link like bing search reference result: link, link page title, link page icon
|
||||
{
|
||||
name: 'title',
|
||||
type: 'string',
|
||||
description: t(`${i18nPrefix}.outputVars.title`),
|
||||
},
|
||||
{
|
||||
name: 'url',
|
||||
type: 'string',
|
||||
description: t(`${i18nPrefix}.outputVars.url`),
|
||||
},
|
||||
{
|
||||
name: 'icon',
|
||||
type: 'string',
|
||||
description: t(`${i18nPrefix}.outputVars.icon`),
|
||||
},
|
||||
{
|
||||
name: 'metadata',
|
||||
type: 'object',
|
||||
description: t(`${i18nPrefix}.outputVars.metadata`),
|
||||
},
|
||||
]}
|
||||
/>
|
||||
|
||||
</>
|
||||
</OutputVars>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(Panel)
|
||||
@@ -0,0 +1,133 @@
|
||||
import type {
|
||||
CommonNodeType,
|
||||
ModelConfig,
|
||||
Node,
|
||||
NodeOutPutVar,
|
||||
ValueSelector,
|
||||
} from '@/app/components/workflow/types'
|
||||
import type { RETRIEVE_TYPE } from '@/types/app'
|
||||
import type {
|
||||
DataSet,
|
||||
MetadataInDoc,
|
||||
RerankingModeEnum,
|
||||
WeightedScoreEnum,
|
||||
} from '@/models/datasets'
|
||||
|
||||
export type MultipleRetrievalConfig = {
|
||||
top_k: number
|
||||
score_threshold: number | null | undefined
|
||||
reranking_model?: {
|
||||
provider: string
|
||||
model: string
|
||||
}
|
||||
reranking_mode?: RerankingModeEnum
|
||||
weights?: {
|
||||
weight_type: WeightedScoreEnum
|
||||
vector_setting: {
|
||||
vector_weight: number
|
||||
embedding_provider_name: string
|
||||
embedding_model_name: string
|
||||
}
|
||||
keyword_setting: {
|
||||
keyword_weight: number
|
||||
}
|
||||
}
|
||||
reranking_enable?: boolean
|
||||
}
|
||||
|
||||
export type SingleRetrievalConfig = {
|
||||
model: ModelConfig
|
||||
}
|
||||
|
||||
export enum LogicalOperator {
|
||||
and = 'and',
|
||||
or = 'or',
|
||||
}
|
||||
|
||||
export enum ComparisonOperator {
|
||||
contains = 'contains',
|
||||
notContains = 'not contains',
|
||||
startWith = 'start with',
|
||||
endWith = 'end with',
|
||||
is = 'is',
|
||||
isNot = 'is not',
|
||||
empty = 'empty',
|
||||
notEmpty = 'not empty',
|
||||
equal = '=',
|
||||
notEqual = '≠',
|
||||
largerThan = '>',
|
||||
lessThan = '<',
|
||||
largerThanOrEqual = '≥',
|
||||
lessThanOrEqual = '≤',
|
||||
isNull = 'is null',
|
||||
isNotNull = 'is not null',
|
||||
in = 'in',
|
||||
notIn = 'not in',
|
||||
allOf = 'all of',
|
||||
exists = 'exists',
|
||||
notExists = 'not exists',
|
||||
before = 'before',
|
||||
after = 'after',
|
||||
}
|
||||
|
||||
export enum MetadataFilteringModeEnum {
|
||||
disabled = 'disabled',
|
||||
automatic = 'automatic',
|
||||
manual = 'manual',
|
||||
}
|
||||
|
||||
export enum MetadataFilteringVariableType {
|
||||
string = 'string',
|
||||
number = 'number',
|
||||
time = 'time',
|
||||
select = 'select',
|
||||
}
|
||||
|
||||
export type MetadataFilteringCondition = {
|
||||
id: string
|
||||
name: string
|
||||
comparison_operator: ComparisonOperator
|
||||
value?: string | number
|
||||
}
|
||||
|
||||
export type MetadataFilteringConditions = {
|
||||
logical_operator: LogicalOperator
|
||||
conditions: MetadataFilteringCondition[]
|
||||
}
|
||||
|
||||
export type KnowledgeRetrievalNodeType = CommonNodeType & {
|
||||
query_variable_selector: ValueSelector
|
||||
dataset_ids: string[]
|
||||
retrieval_mode: RETRIEVE_TYPE
|
||||
multiple_retrieval_config?: MultipleRetrievalConfig
|
||||
single_retrieval_config?: SingleRetrievalConfig
|
||||
_datasets?: DataSet[]
|
||||
metadata_filtering_mode?: MetadataFilteringModeEnum
|
||||
metadata_filtering_conditions?: MetadataFilteringConditions
|
||||
metadata_model_config?: ModelConfig
|
||||
}
|
||||
|
||||
export type HandleAddCondition = (metadataItem: MetadataInDoc) => void
|
||||
export type HandleRemoveCondition = (id: string) => void
|
||||
export type HandleUpdateCondition = (id: string, newCondition: MetadataFilteringCondition) => void
|
||||
export type HandleToggleConditionLogicalOperator = () => void
|
||||
|
||||
export type MetadataShape = {
|
||||
metadataList?: MetadataInDoc[]
|
||||
selectedDatasetsLoaded?: boolean
|
||||
metadataFilteringConditions?: MetadataFilteringConditions
|
||||
handleAddCondition: HandleAddCondition
|
||||
handleRemoveCondition: HandleRemoveCondition
|
||||
handleToggleConditionLogicalOperator: HandleToggleConditionLogicalOperator
|
||||
handleUpdateCondition: HandleUpdateCondition
|
||||
metadataModelConfig?: ModelConfig
|
||||
handleMetadataModelChange?: (model: { modelId: string; provider: string; mode?: string; features?: string[] }) => void
|
||||
handleMetadataCompletionParamsChange?: (params: Record<string, any>) => void
|
||||
availableStringVars?: NodeOutPutVar[]
|
||||
availableStringNodesWithParent?: Node[]
|
||||
availableNumberVars?: NodeOutPutVar[]
|
||||
availableNumberNodesWithParent?: Node[]
|
||||
isCommonVariable?: boolean
|
||||
availableCommonStringVars?: { name: string; type: string; value: string }[]
|
||||
availableCommonNumberVars?: { name: string; type: string; value: string }[]
|
||||
}
|
||||
@@ -0,0 +1,416 @@
|
||||
import {
|
||||
useCallback,
|
||||
useEffect,
|
||||
useRef,
|
||||
useState,
|
||||
} from 'react'
|
||||
import { produce } from 'immer'
|
||||
import { isEqual } from 'lodash-es'
|
||||
import { v4 as uuid4 } from 'uuid'
|
||||
import type { ValueSelector, Var } from '../../types'
|
||||
import { BlockEnum, VarType } from '../../types'
|
||||
import {
|
||||
useIsChatMode,
|
||||
useNodesReadOnly,
|
||||
useWorkflow,
|
||||
} from '../../hooks'
|
||||
import type {
|
||||
HandleAddCondition,
|
||||
HandleRemoveCondition,
|
||||
HandleToggleConditionLogicalOperator,
|
||||
HandleUpdateCondition,
|
||||
KnowledgeRetrievalNodeType,
|
||||
MetadataFilteringModeEnum,
|
||||
MultipleRetrievalConfig,
|
||||
} from './types'
|
||||
import {
|
||||
ComparisonOperator,
|
||||
LogicalOperator,
|
||||
MetadataFilteringVariableType,
|
||||
} from './types'
|
||||
import {
|
||||
getMultipleRetrievalConfig,
|
||||
getSelectedDatasetsMode,
|
||||
} from './utils'
|
||||
import { AppModeEnum, RETRIEVE_TYPE } from '@/types/app'
|
||||
import { DATASET_DEFAULT } from '@/config'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import { fetchDatasets } from '@/service/datasets'
|
||||
import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
|
||||
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import useAvailableVarList from '@/app/components/workflow/nodes/_base/hooks/use-available-var-list'
|
||||
import { useDatasetsDetailStore } from '../../datasets-detail-store/store'
|
||||
|
||||
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
const { nodesReadOnly: readOnly } = useNodesReadOnly()
|
||||
const isChatMode = useIsChatMode()
|
||||
const { getBeforeNodesInSameBranch } = useWorkflow()
|
||||
const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start)
|
||||
const startNodeId = startNode?.id
|
||||
const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
|
||||
const updateDatasetsDetail = useDatasetsDetailStore(s => s.updateDatasetsDetail)
|
||||
|
||||
const inputRef = useRef(inputs)
|
||||
|
||||
const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
|
||||
const newInputs = produce(s, (draft) => {
|
||||
if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
|
||||
delete draft.single_retrieval_config
|
||||
else
|
||||
delete draft.multiple_retrieval_config
|
||||
})
|
||||
// not work in pass to draft...
|
||||
doSetInputs(newInputs)
|
||||
inputRef.current = newInputs
|
||||
}, [doSetInputs])
|
||||
|
||||
const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.query_variable_selector = newVar as ValueSelector
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [inputs, setInputs])
|
||||
|
||||
const {
|
||||
currentProvider,
|
||||
currentModel,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)
|
||||
|
||||
const {
|
||||
modelList: rerankModelList,
|
||||
defaultModel: rerankDefaultModel,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||
|
||||
const {
|
||||
currentModel: currentRerankModel,
|
||||
currentProvider: currentRerankProvider,
|
||||
} = useCurrentProviderAndModel(
|
||||
rerankModelList,
|
||||
rerankDefaultModel
|
||||
? {
|
||||
...rerankDefaultModel,
|
||||
provider: rerankDefaultModel.provider.provider,
|
||||
}
|
||||
: undefined,
|
||||
)
|
||||
|
||||
const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
|
||||
const newInputs = produce(inputRef.current, (draft) => {
|
||||
if (!draft.single_retrieval_config) {
|
||||
draft.single_retrieval_config = {
|
||||
model: {
|
||||
provider: '',
|
||||
name: '',
|
||||
mode: '',
|
||||
completion_params: {},
|
||||
},
|
||||
}
|
||||
}
|
||||
const draftModel = draft.single_retrieval_config?.model
|
||||
draftModel.provider = model.provider
|
||||
draftModel.name = model.modelId
|
||||
draftModel.mode = model.mode!
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [setInputs])
|
||||
|
||||
const handleCompletionParamsChange = useCallback((newParams: Record<string, any>) => {
|
||||
// inputRef.current.single_retrieval_config?.model is old when change the provider...
|
||||
if (isEqual(newParams, inputRef.current.single_retrieval_config?.model.completion_params))
|
||||
return
|
||||
|
||||
const newInputs = produce(inputRef.current, (draft) => {
|
||||
if (!draft.single_retrieval_config) {
|
||||
draft.single_retrieval_config = {
|
||||
model: {
|
||||
provider: '',
|
||||
name: '',
|
||||
mode: '',
|
||||
completion_params: {},
|
||||
},
|
||||
}
|
||||
}
|
||||
draft.single_retrieval_config.model.completion_params = newParams
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [setInputs])
|
||||
|
||||
// set defaults models
|
||||
useEffect(() => {
|
||||
const inputs = inputRef.current
|
||||
if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider && currentRerankModel && rerankDefaultModel)
|
||||
return
|
||||
|
||||
if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider)
|
||||
return
|
||||
|
||||
const newInput = produce(inputs, (draft) => {
|
||||
if (currentProvider?.provider && currentModel?.model) {
|
||||
const hasSetModel = draft.single_retrieval_config?.model?.provider
|
||||
if (!hasSetModel) {
|
||||
draft.single_retrieval_config = {
|
||||
model: {
|
||||
provider: currentProvider?.provider,
|
||||
name: currentModel?.model,
|
||||
mode: currentModel?.model_properties?.mode as string,
|
||||
completion_params: {},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
||||
draft.multiple_retrieval_config = {
|
||||
top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
|
||||
score_threshold: multipleRetrievalConfig?.score_threshold,
|
||||
reranking_model: multipleRetrievalConfig?.reranking_model,
|
||||
reranking_mode: multipleRetrievalConfig?.reranking_mode,
|
||||
weights: multipleRetrievalConfig?.weights,
|
||||
reranking_enable: multipleRetrievalConfig?.reranking_enable !== undefined
|
||||
? multipleRetrievalConfig.reranking_enable
|
||||
: Boolean(currentRerankModel && rerankDefaultModel),
|
||||
}
|
||||
})
|
||||
setInputs(newInput)
|
||||
}, [currentProvider?.provider, currentModel, currentRerankModel, rerankDefaultModel])
|
||||
const [selectedDatasets, setSelectedDatasets] = useState<DataSet[]>([])
|
||||
const [rerankModelOpen, setRerankModelOpen] = useState(false)
|
||||
const handleRetrievalModeChange = useCallback((newMode: RETRIEVE_TYPE) => {
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.retrieval_mode = newMode
|
||||
if (newMode === RETRIEVE_TYPE.multiWay) {
|
||||
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
||||
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, {
|
||||
provider: currentRerankProvider?.provider,
|
||||
model: currentRerankModel?.model,
|
||||
})
|
||||
}
|
||||
else {
|
||||
const hasSetModel = draft.single_retrieval_config?.model?.provider
|
||||
if (!hasSetModel) {
|
||||
draft.single_retrieval_config = {
|
||||
model: {
|
||||
provider: currentProvider?.provider || '',
|
||||
name: currentModel?.model || '',
|
||||
mode: currentModel?.model_properties?.mode as string,
|
||||
completion_params: {},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider])
|
||||
|
||||
const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
const newMultipleRetrievalConfig = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, {
|
||||
provider: currentRerankProvider?.provider,
|
||||
model: currentRerankModel?.model,
|
||||
})
|
||||
draft.multiple_retrieval_config = newMultipleRetrievalConfig
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider])
|
||||
|
||||
const [selectedDatasetsLoaded, setSelectedDatasetsLoaded] = useState(false)
|
||||
// datasets
|
||||
useEffect(() => {
|
||||
(async () => {
|
||||
const inputs = inputRef.current
|
||||
const datasetIds = inputs.dataset_ids
|
||||
if (datasetIds?.length > 0) {
|
||||
const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } as any })
|
||||
setSelectedDatasets(dataSetsWithDetail)
|
||||
}
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.dataset_ids = datasetIds
|
||||
})
|
||||
setInputs(newInputs)
|
||||
setSelectedDatasetsLoaded(true)
|
||||
})()
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
const inputs = inputRef.current
|
||||
let query_variable_selector: ValueSelector = inputs.query_variable_selector
|
||||
if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId)
|
||||
query_variable_selector = [startNodeId, 'sys.query']
|
||||
|
||||
setInputs(produce(inputs, (draft) => {
|
||||
draft.query_variable_selector = query_variable_selector
|
||||
}))
|
||||
}, [])
|
||||
|
||||
const handleOnDatasetsChange = useCallback((newDatasets: DataSet[]) => {
|
||||
const {
|
||||
mixtureHighQualityAndEconomic,
|
||||
mixtureInternalAndExternal,
|
||||
inconsistentEmbeddingModel,
|
||||
allInternal,
|
||||
allExternal,
|
||||
} = getSelectedDatasetsMode(newDatasets)
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.dataset_ids = newDatasets.map(d => d.id)
|
||||
|
||||
if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
|
||||
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
||||
const newMultipleRetrievalConfig = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, {
|
||||
provider: currentRerankProvider?.provider,
|
||||
model: currentRerankModel?.model,
|
||||
})
|
||||
draft.multiple_retrieval_config = newMultipleRetrievalConfig
|
||||
}
|
||||
})
|
||||
updateDatasetsDetail(newDatasets)
|
||||
setInputs(newInputs)
|
||||
setSelectedDatasets(newDatasets)
|
||||
|
||||
if (
|
||||
(allInternal && (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel))
|
||||
|| mixtureInternalAndExternal
|
||||
|| allExternal
|
||||
)
|
||||
setRerankModelOpen(true)
|
||||
}, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel, currentRerankProvider, updateDatasetsDetail])
|
||||
|
||||
const filterVar = useCallback((varPayload: Var) => {
|
||||
return varPayload.type === VarType.string
|
||||
}, [])
|
||||
|
||||
const handleMetadataFilterModeChange = useCallback((newMode: MetadataFilteringModeEnum) => {
|
||||
setInputs(produce(inputRef.current, (draft) => {
|
||||
draft.metadata_filtering_mode = newMode
|
||||
}))
|
||||
}, [setInputs])
|
||||
|
||||
const handleAddCondition = useCallback<HandleAddCondition>(({ name, type }) => {
|
||||
let operator: ComparisonOperator = ComparisonOperator.is
|
||||
|
||||
if (type === MetadataFilteringVariableType.number)
|
||||
operator = ComparisonOperator.equal
|
||||
|
||||
const newCondition = {
|
||||
id: uuid4(),
|
||||
name,
|
||||
comparison_operator: operator,
|
||||
}
|
||||
|
||||
const newInputs = produce(inputRef.current, (draft) => {
|
||||
if (draft.metadata_filtering_conditions) {
|
||||
draft.metadata_filtering_conditions.conditions.push(newCondition)
|
||||
}
|
||||
else {
|
||||
draft.metadata_filtering_conditions = {
|
||||
logical_operator: LogicalOperator.and,
|
||||
conditions: [newCondition],
|
||||
}
|
||||
}
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [setInputs])
|
||||
|
||||
const handleRemoveCondition = useCallback<HandleRemoveCondition>((id) => {
|
||||
const conditions = inputRef.current.metadata_filtering_conditions?.conditions || []
|
||||
const index = conditions.findIndex(c => c.id === id)
|
||||
const newInputs = produce(inputRef.current, (draft) => {
|
||||
if (index > -1)
|
||||
draft.metadata_filtering_conditions?.conditions.splice(index, 1)
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [setInputs])
|
||||
|
||||
const handleUpdateCondition = useCallback<HandleUpdateCondition>((id, newCondition) => {
|
||||
const conditions = inputRef.current.metadata_filtering_conditions?.conditions || []
|
||||
const index = conditions.findIndex(c => c.id === id)
|
||||
const newInputs = produce(inputRef.current, (draft) => {
|
||||
if (index > -1)
|
||||
draft.metadata_filtering_conditions!.conditions[index] = newCondition
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [setInputs])
|
||||
|
||||
const handleToggleConditionLogicalOperator = useCallback<HandleToggleConditionLogicalOperator>(() => {
|
||||
const oldLogicalOperator = inputRef.current.metadata_filtering_conditions?.logical_operator
|
||||
const newLogicalOperator = oldLogicalOperator === LogicalOperator.and ? LogicalOperator.or : LogicalOperator.and
|
||||
const newInputs = produce(inputRef.current, (draft) => {
|
||||
draft.metadata_filtering_conditions!.logical_operator = newLogicalOperator
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [setInputs])
|
||||
|
||||
const handleMetadataModelChange = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
|
||||
const newInputs = produce(inputRef.current, (draft) => {
|
||||
draft.metadata_model_config = {
|
||||
provider: model.provider,
|
||||
name: model.modelId,
|
||||
mode: model.mode || AppModeEnum.CHAT,
|
||||
completion_params: draft.metadata_model_config?.completion_params || { temperature: 0.7 },
|
||||
}
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [setInputs])
|
||||
|
||||
const handleMetadataCompletionParamsChange = useCallback((newParams: Record<string, any>) => {
|
||||
const newInputs = produce(inputRef.current, (draft) => {
|
||||
draft.metadata_model_config = {
|
||||
...draft.metadata_model_config!,
|
||||
completion_params: newParams,
|
||||
}
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [setInputs])
|
||||
|
||||
const filterStringVar = useCallback((varPayload: Var) => {
|
||||
return [VarType.string].includes(varPayload.type)
|
||||
}, [])
|
||||
|
||||
const {
|
||||
availableVars: availableStringVars,
|
||||
availableNodesWithParent: availableStringNodesWithParent,
|
||||
} = useAvailableVarList(id, {
|
||||
onlyLeafNodeVar: false,
|
||||
filterVar: filterStringVar,
|
||||
})
|
||||
|
||||
const filterNumberVar = useCallback((varPayload: Var) => {
|
||||
return [VarType.number].includes(varPayload.type)
|
||||
}, [])
|
||||
|
||||
const {
|
||||
availableVars: availableNumberVars,
|
||||
availableNodesWithParent: availableNumberNodesWithParent,
|
||||
} = useAvailableVarList(id, {
|
||||
onlyLeafNodeVar: false,
|
||||
filterVar: filterNumberVar,
|
||||
})
|
||||
|
||||
return {
|
||||
readOnly,
|
||||
inputs,
|
||||
handleQueryVarChange,
|
||||
filterVar,
|
||||
handleRetrievalModeChange,
|
||||
handleMultipleRetrievalConfigChange,
|
||||
handleModelChanged,
|
||||
handleCompletionParamsChange,
|
||||
selectedDatasets: selectedDatasets.filter(d => d.name),
|
||||
selectedDatasetsLoaded,
|
||||
handleOnDatasetsChange,
|
||||
rerankModelOpen,
|
||||
setRerankModelOpen,
|
||||
handleMetadataFilterModeChange,
|
||||
handleUpdateCondition,
|
||||
handleAddCondition,
|
||||
handleRemoveCondition,
|
||||
handleToggleConditionLogicalOperator,
|
||||
handleMetadataModelChange,
|
||||
handleMetadataCompletionParamsChange,
|
||||
availableStringVars,
|
||||
availableStringNodesWithParent,
|
||||
availableNumberVars,
|
||||
availableNumberNodesWithParent,
|
||||
}
|
||||
}
|
||||
|
||||
export default useConfig
|
||||
@@ -0,0 +1,63 @@
|
||||
import type { RefObject } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import type { InputVar, Variable } from '@/app/components/workflow/types'
|
||||
import { InputVarType } from '@/app/components/workflow/types'
|
||||
import { useCallback, useMemo } from 'react'
|
||||
import type { KnowledgeRetrievalNodeType } from './types'
|
||||
|
||||
const i18nPrefix = 'workflow.nodes.knowledgeRetrieval'
|
||||
|
||||
type Params = {
|
||||
id: string,
|
||||
payload: KnowledgeRetrievalNodeType
|
||||
runInputData: Record<string, any>
|
||||
runInputDataRef: RefObject<Record<string, any>>
|
||||
getInputVars: (textList: string[]) => InputVar[]
|
||||
setRunInputData: (data: Record<string, any>) => void
|
||||
toVarInputs: (variables: Variable[]) => InputVar[]
|
||||
}
|
||||
const useSingleRunFormParams = ({
|
||||
payload,
|
||||
runInputData,
|
||||
setRunInputData,
|
||||
}: Params) => {
|
||||
const { t } = useTranslation()
|
||||
const query = runInputData.query
|
||||
const setQuery = useCallback((newQuery: string) => {
|
||||
setRunInputData({
|
||||
...runInputData,
|
||||
query: newQuery,
|
||||
})
|
||||
}, [runInputData, setRunInputData])
|
||||
|
||||
const forms = useMemo(() => {
|
||||
return [
|
||||
{
|
||||
inputs: [{
|
||||
label: t(`${i18nPrefix}.queryVariable`)!,
|
||||
variable: 'query',
|
||||
type: InputVarType.paragraph,
|
||||
required: true,
|
||||
}],
|
||||
values: { query },
|
||||
onChange: (keyValue: Record<string, any>) => setQuery(keyValue.query),
|
||||
},
|
||||
]
|
||||
}, [query, setQuery, t])
|
||||
|
||||
const getDependentVars = () => {
|
||||
return [payload.query_variable_selector]
|
||||
}
|
||||
const getDependentVar = (variable: string) => {
|
||||
if(variable === 'query')
|
||||
return payload.query_variable_selector
|
||||
}
|
||||
|
||||
return {
|
||||
forms,
|
||||
getDependentVars,
|
||||
getDependentVar,
|
||||
}
|
||||
}
|
||||
|
||||
export default useSingleRunFormParams
|
||||
@@ -0,0 +1,288 @@
|
||||
import {
|
||||
uniq,
|
||||
xorBy,
|
||||
} from 'lodash-es'
|
||||
import type { MultipleRetrievalConfig } from './types'
|
||||
import type {
|
||||
DataSet,
|
||||
SelectedDatasetsMode,
|
||||
} from '@/models/datasets'
|
||||
import {
|
||||
DEFAULT_WEIGHTED_SCORE,
|
||||
RerankingModeEnum,
|
||||
WeightedScoreEnum,
|
||||
} from '@/models/datasets'
|
||||
import { RETRIEVE_METHOD } from '@/types/app'
|
||||
import { DATASET_DEFAULT } from '@/config'
|
||||
|
||||
export const checkNodeValid = () => {
|
||||
return true
|
||||
}
|
||||
|
||||
export const getSelectedDatasetsMode = (datasets: DataSet[] = []) => {
|
||||
if (datasets === null)
|
||||
datasets = []
|
||||
let allHighQuality = true
|
||||
let allHighQualityVectorSearch = true
|
||||
let allHighQualityFullTextSearch = true
|
||||
let allEconomic = true
|
||||
let mixtureHighQualityAndEconomic = true
|
||||
let allExternal = true
|
||||
let allInternal = true
|
||||
let mixtureInternalAndExternal = true
|
||||
let inconsistentEmbeddingModel = false
|
||||
if (!datasets.length) {
|
||||
allHighQuality = false
|
||||
allHighQualityVectorSearch = false
|
||||
allHighQualityFullTextSearch = false
|
||||
allEconomic = false
|
||||
mixtureHighQualityAndEconomic = false
|
||||
allExternal = false
|
||||
allInternal = false
|
||||
mixtureInternalAndExternal = false
|
||||
}
|
||||
datasets.forEach((dataset) => {
|
||||
if (dataset.indexing_technique === 'economy') {
|
||||
allHighQuality = false
|
||||
allHighQualityVectorSearch = false
|
||||
allHighQualityFullTextSearch = false
|
||||
}
|
||||
if (dataset.indexing_technique === 'high_quality') {
|
||||
allEconomic = false
|
||||
|
||||
if (dataset.retrieval_model_dict.search_method !== RETRIEVE_METHOD.semantic)
|
||||
allHighQualityVectorSearch = false
|
||||
|
||||
if (dataset.retrieval_model_dict.search_method !== RETRIEVE_METHOD.fullText)
|
||||
allHighQualityFullTextSearch = false
|
||||
}
|
||||
if (dataset.provider !== 'external') {
|
||||
allExternal = false
|
||||
}
|
||||
else {
|
||||
allInternal = false
|
||||
allHighQuality = false
|
||||
allHighQualityVectorSearch = false
|
||||
allHighQualityFullTextSearch = false
|
||||
mixtureHighQualityAndEconomic = false
|
||||
}
|
||||
})
|
||||
|
||||
if (allExternal || allInternal)
|
||||
mixtureInternalAndExternal = false
|
||||
|
||||
if (allHighQuality || allEconomic)
|
||||
mixtureHighQualityAndEconomic = false
|
||||
|
||||
if (allHighQuality)
|
||||
inconsistentEmbeddingModel = uniq(datasets.map(item => item.embedding_model)).length > 1
|
||||
|
||||
return {
|
||||
allHighQuality,
|
||||
allHighQualityVectorSearch,
|
||||
allHighQualityFullTextSearch,
|
||||
allEconomic,
|
||||
mixtureHighQualityAndEconomic,
|
||||
allInternal,
|
||||
allExternal,
|
||||
mixtureInternalAndExternal,
|
||||
inconsistentEmbeddingModel,
|
||||
} as SelectedDatasetsMode
|
||||
}
|
||||
|
||||
export const getMultipleRetrievalConfig = (
|
||||
multipleRetrievalConfig: MultipleRetrievalConfig,
|
||||
selectedDatasets: DataSet[],
|
||||
originalDatasets: DataSet[],
|
||||
fallbackRerankModel?: { provider?: string; model?: string }, // fallback rerank model
|
||||
) => {
|
||||
// Check if the selected datasets are different from the original datasets
|
||||
const isDatasetsChanged = xorBy(selectedDatasets, originalDatasets, 'id').length > 0
|
||||
// Check if the rerank model is valid
|
||||
const isFallbackRerankModelValid = !!(fallbackRerankModel?.provider && fallbackRerankModel?.model)
|
||||
|
||||
const {
|
||||
allHighQuality,
|
||||
allHighQualityVectorSearch,
|
||||
allHighQualityFullTextSearch,
|
||||
allEconomic,
|
||||
mixtureHighQualityAndEconomic,
|
||||
allInternal,
|
||||
allExternal,
|
||||
mixtureInternalAndExternal,
|
||||
inconsistentEmbeddingModel,
|
||||
} = getSelectedDatasetsMode(selectedDatasets)
|
||||
|
||||
const {
|
||||
top_k = DATASET_DEFAULT.top_k,
|
||||
score_threshold,
|
||||
reranking_mode,
|
||||
reranking_model,
|
||||
weights,
|
||||
reranking_enable,
|
||||
} = multipleRetrievalConfig || { top_k: DATASET_DEFAULT.top_k }
|
||||
|
||||
const result = {
|
||||
top_k,
|
||||
score_threshold,
|
||||
reranking_mode,
|
||||
reranking_model,
|
||||
weights,
|
||||
reranking_enable,
|
||||
}
|
||||
|
||||
const setDefaultWeights = () => {
|
||||
result.weights = {
|
||||
weight_type: WeightedScoreEnum.Customized,
|
||||
vector_setting: {
|
||||
vector_weight: allHighQualityVectorSearch
|
||||
? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.semantic
|
||||
// eslint-disable-next-line sonarjs/no-nested-conditional
|
||||
: allHighQualityFullTextSearch
|
||||
? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.semantic
|
||||
: DEFAULT_WEIGHTED_SCORE.other.semantic,
|
||||
embedding_provider_name: selectedDatasets[0].embedding_model_provider,
|
||||
embedding_model_name: selectedDatasets[0].embedding_model,
|
||||
},
|
||||
keyword_setting: {
|
||||
keyword_weight: allHighQualityVectorSearch
|
||||
? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.keyword
|
||||
// eslint-disable-next-line sonarjs/no-nested-conditional
|
||||
: allHighQualityFullTextSearch
|
||||
? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.keyword
|
||||
: DEFAULT_WEIGHTED_SCORE.other.keyword,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* In this case, user can manually toggle reranking
|
||||
* So should keep the reranking_enable value
|
||||
* But the default reranking_model should be set
|
||||
*/
|
||||
if ((allEconomic && allInternal) || allExternal) {
|
||||
result.reranking_mode = RerankingModeEnum.RerankingModel
|
||||
// Need to check if the reranking model should be set to default when first time initialized
|
||||
if ((!result.reranking_model?.provider || !result.reranking_model?.model) && isFallbackRerankModelValid) {
|
||||
result.reranking_model = {
|
||||
provider: fallbackRerankModel.provider || '',
|
||||
model: fallbackRerankModel.model || '',
|
||||
}
|
||||
}
|
||||
result.reranking_enable = reranking_enable
|
||||
}
|
||||
|
||||
/**
|
||||
* In this case, reranking_enable must be true
|
||||
* And if rerank model is not set, should set the default rerank model
|
||||
*/
|
||||
if (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || mixtureInternalAndExternal) {
|
||||
result.reranking_mode = RerankingModeEnum.RerankingModel
|
||||
// Need to check if the reranking model should be set to default when first time initialized
|
||||
if ((!result.reranking_model?.provider || !result.reranking_model?.model) && isFallbackRerankModelValid) {
|
||||
result.reranking_model = {
|
||||
provider: fallbackRerankModel.provider || '',
|
||||
model: fallbackRerankModel.model || '',
|
||||
}
|
||||
}
|
||||
result.reranking_enable = true
|
||||
}
|
||||
|
||||
/**
|
||||
* In this case, user can choose to use weighted score or rerank model
|
||||
* But if the reranking_mode is not initialized, should set the default rerank model and reranking_enable to true
|
||||
* and set reranking_mode to reranking_model
|
||||
*/
|
||||
if (allHighQuality && !inconsistentEmbeddingModel && allInternal) {
|
||||
// If not initialized, check if the default rerank model is valid
|
||||
if (!reranking_mode) {
|
||||
if (isFallbackRerankModelValid) {
|
||||
result.reranking_mode = RerankingModeEnum.RerankingModel
|
||||
result.reranking_enable = true
|
||||
|
||||
result.reranking_model = {
|
||||
provider: fallbackRerankModel.provider || '',
|
||||
model: fallbackRerankModel.model || '',
|
||||
}
|
||||
}
|
||||
else {
|
||||
result.reranking_mode = RerankingModeEnum.WeightedScore
|
||||
result.reranking_enable = false
|
||||
setDefaultWeights()
|
||||
}
|
||||
}
|
||||
|
||||
// After initialization, if datasets has no change, make sure the config has correct value
|
||||
if (reranking_mode === RerankingModeEnum.WeightedScore) {
|
||||
result.reranking_enable = false
|
||||
if (!weights)
|
||||
setDefaultWeights()
|
||||
}
|
||||
if (reranking_mode === RerankingModeEnum.RerankingModel) {
|
||||
if ((!result.reranking_model?.provider || !result.reranking_model?.model) && isFallbackRerankModelValid) {
|
||||
result.reranking_model = {
|
||||
provider: fallbackRerankModel.provider || '',
|
||||
model: fallbackRerankModel.model || '',
|
||||
}
|
||||
}
|
||||
result.reranking_enable = true
|
||||
}
|
||||
|
||||
// Need to check if reranking_mode should be set to reranking_model when datasets changed
|
||||
if (reranking_mode === RerankingModeEnum.WeightedScore && weights && isDatasetsChanged) {
|
||||
if ((result.reranking_model?.provider && result.reranking_model?.model) || isFallbackRerankModelValid) {
|
||||
result.reranking_mode = RerankingModeEnum.RerankingModel
|
||||
result.reranking_enable = true
|
||||
|
||||
// eslint-disable-next-line sonarjs/nested-control-flow
|
||||
if ((!result.reranking_model?.provider || !result.reranking_model?.model) && isFallbackRerankModelValid) {
|
||||
result.reranking_model = {
|
||||
provider: fallbackRerankModel.provider || '',
|
||||
model: fallbackRerankModel.model || '',
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
setDefaultWeights()
|
||||
}
|
||||
}
|
||||
// Need to switch to weighted score when reranking model is not valid and datasets changed
|
||||
if (
|
||||
reranking_mode === RerankingModeEnum.RerankingModel
|
||||
&& (!result.reranking_model?.provider || !result.reranking_model?.model)
|
||||
&& !isFallbackRerankModelValid
|
||||
&& isDatasetsChanged
|
||||
) {
|
||||
result.reranking_mode = RerankingModeEnum.WeightedScore
|
||||
result.reranking_enable = false
|
||||
setDefaultWeights()
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
export const checkoutRerankModelConfiguredInRetrievalSettings = (
|
||||
datasets: DataSet[],
|
||||
multipleRetrievalConfig?: MultipleRetrievalConfig,
|
||||
) => {
|
||||
if (!multipleRetrievalConfig)
|
||||
return true
|
||||
|
||||
const {
|
||||
allEconomic,
|
||||
allExternal,
|
||||
allInternal,
|
||||
} = getSelectedDatasetsMode(datasets)
|
||||
|
||||
const {
|
||||
reranking_enable,
|
||||
reranking_mode,
|
||||
reranking_model,
|
||||
} = multipleRetrievalConfig
|
||||
|
||||
if (reranking_mode === RerankingModeEnum.RerankingModel && (!reranking_model?.provider || !reranking_model?.model))
|
||||
return ((allEconomic && allInternal) || allExternal) && !reranking_enable
|
||||
|
||||
return true
|
||||
}
|
||||
Reference in New Issue
Block a user