Skip to content

fix: retrieval setting validate #10454

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,16 @@ const DatasetConfig: FC = () => {

const {
currentModel: currentRerankModel,
currentProvider: currentRerankProvider,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)

const onRemove = (id: string) => {
const filteredDataSets = dataSet.filter(item => item.id !== id)
setDataSet(filteredDataSets)
const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, !!currentRerankModel)
const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, {
provider: currentRerankProvider?.provider,
model: currentRerankModel?.model,
})
setDatasetConfigs({
...(datasetConfigs as any),
...retrievalConfig,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ const ConfigContent: FC<Props> = ({
return false

return datasetConfigs.reranking_enable
}, [canManuallyToggleRerank, datasetConfigs.reranking_enable])
}, [canManuallyToggleRerank, datasetConfigs.reranking_enable, isRerankDefaultModelValid])

const handleDisabledSwitchClick = useCallback(() => {
if (!currentRerankModel && !showRerankModel)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ const ParamsConfig = ({
const {
defaultModel: rerankDefaultModel,
currentModel: isRerankDefaultModelValid,
currentProvider: rerankDefaultProvider,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)

const isValid = () => {
Expand Down Expand Up @@ -91,7 +92,10 @@ const ParamsConfig = ({
reranking_mode: restConfigs.reranking_mode,
weights: restConfigs.weights,
reranking_enable: restConfigs.reranking_enable,
}, selectedDatasets, selectedDatasets, !!isRerankDefaultModelValid)
}, selectedDatasets, selectedDatasets, {
provider: rerankDefaultProvider?.provider,
model: isRerankDefaultModelValid?.model,
})

setTempDataSetConfigs({
...retrievalConfig,
Expand Down
11 changes: 9 additions & 2 deletions web/app/components/app/configuration/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ const Configuration: FC = () => {
const [rerankSettingModalOpen, setRerankSettingModalOpen] = useState(false)
const {
currentModel: currentRerankModel,
currentProvider: currentRerankProvider,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const handleSelect = (data: DataSet[]) => {
if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) {
Expand Down Expand Up @@ -279,7 +280,10 @@ const Configuration: FC = () => {
reranking_mode: restConfigs.reranking_mode,
weights: restConfigs.weights,
reranking_enable: restConfigs.reranking_enable,
}, newDatasets, dataSets, !!currentRerankModel)
}, newDatasets, dataSets, {
provider: currentRerankProvider?.provider,
model: currentRerankModel?.model,
})

setDatasetConfigs({
...retrievalConfig,
Expand Down Expand Up @@ -620,7 +624,10 @@ const Configuration: FC = () => {

syncToPublishedConfig(config)
setPublishedConfig(config)
const retrievalConfig = getMultipleRetrievalConfig(modelConfig.dataset_configs, datasets, datasets, !!currentRerankModel)
const retrievalConfig = getMultipleRetrievalConfig(modelConfig.dataset_configs, datasets, datasets, {
provider: currentRerankProvider?.provider,
model: currentRerankModel?.model,
})
setDatasetConfigs({
retrieval_model: RETRIEVE_TYPE.multiWay,
...modelConfig.dataset_configs,
Expand Down
13 changes: 9 additions & 4 deletions web/app/components/workflow/nodes/knowledge-retrieval/default.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { BlockEnum } from '../../types'
import type { NodeDefault } from '../../types'
import type { KnowledgeRetrievalNodeType } from './types'
import { RerankingModeEnum } from '@/models/datasets'
import { checkoutRerankModelConfigedInRetrievalSettings } from './utils'
import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants'
import { DATASET_DEFAULT } from '@/config'
import { RETRIEVE_TYPE } from '@/types/app'
Expand Down Expand Up @@ -36,12 +36,17 @@ const nodeDefault: NodeDefault<KnowledgeRetrievalNodeType> = {
if (!errorMessages && (!payload.dataset_ids || payload.dataset_ids.length === 0))
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.knowledgeRetrieval.knowledge`) })

if (!errorMessages && payload.retrieval_mode === RETRIEVE_TYPE.multiWay && payload.multiple_retrieval_config?.reranking_mode === RerankingModeEnum.RerankingModel && !payload.multiple_retrieval_config?.reranking_model?.provider && payload.multiple_retrieval_config?.reranking_enable)
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) })

if (!errorMessages && payload.retrieval_mode === RETRIEVE_TYPE.oneWay && !payload.single_retrieval_config?.model?.provider)
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t('common.modelProvider.systemReasoningModel.key') })

const { _datasets, multiple_retrieval_config, retrieval_mode } = payload
if (retrieval_mode === RETRIEVE_TYPE.multiWay) {
const checked = checkoutRerankModelConfigedInRetrievalSettings(_datasets || [], multiple_retrieval_config)

if (!errorMessages && !checked)
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) })
}

return {
isValid: !errorMessages,
errorMessage: errorMessages,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type { CommonNodeType, ModelConfig, ValueSelector } from '@/app/components/workflow/types'
import type { RETRIEVE_TYPE } from '@/types/app'
import type {
DataSet,
RerankingModeEnum,
} from '@/models/datasets'

Expand Down Expand Up @@ -35,4 +36,5 @@ export type KnowledgeRetrievalNodeType = CommonNodeType & {
retrieval_mode: RETRIEVE_TYPE
multiple_retrieval_config?: MultipleRetrievalConfig
single_retrieval_config?: SingleRetrievalConfig
_datasets?: DataSet[]
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {

const {
currentModel: currentRerankModel,
currentProvider: currentRerankProvider,
} = useCurrentProviderAndModel(
rerankModelList,
rerankDefaultModel
Expand Down Expand Up @@ -163,7 +164,10 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
draft.retrieval_mode = newMode
if (newMode === RETRIEVE_TYPE.multiWay) {
const multipleRetrievalConfig = draft.multiple_retrieval_config
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel)
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, {
provider: currentRerankProvider?.provider,
model: currentRerankModel?.model,
})
}
else {
const hasSetModel = draft.single_retrieval_config?.model?.provider
Expand All @@ -180,14 +184,17 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
}
})
setInputs(newInputs)
}, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel])
}, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider])

const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
const newInputs = produce(inputs, (draft) => {
draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel)
draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, {
provider: currentRerankProvider?.provider,
model: currentRerankModel?.model,
})
})
setInputs(newInputs)
}, [inputs, setInputs, selectedDatasets, currentRerankModel])
}, [inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider])

// datasets
useEffect(() => {
Expand All @@ -200,6 +207,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
}
const newInputs = produce(inputs, (draft) => {
draft.dataset_ids = datasetIds
draft._datasets = selectedDatasets
})
setInputs(newInputs)
})()
Expand Down Expand Up @@ -228,10 +236,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
} = getSelectedDatasetsMode(newDatasets)
const newInputs = produce(inputs, (draft) => {
draft.dataset_ids = newDatasets.map(d => d.id)
draft._datasets = newDatasets

if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
const multipleRetrievalConfig = draft.multiple_retrieval_config
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, !!currentRerankModel)
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, {
provider: currentRerankProvider?.provider,
model: currentRerankModel?.model,
})
}
})
setInputs(newInputs)
Expand All @@ -243,7 +255,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|| allExternal
)
setRerankModelOpen(true)
}, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel])
}, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel, currentRerankProvider])

const filterVar = useCallback((varPayload: Var) => {
return varPayload.type === VarType.string
Expand Down
113 changes: 80 additions & 33 deletions web/app/components/workflow/nodes/knowledge-retrieval/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,10 @@ export const getMultipleRetrievalConfig = (
multipleRetrievalConfig: MultipleRetrievalConfig,
selectedDatasets: DataSet[],
originalDatasets: DataSet[],
isValidRerankModel?: boolean,
validRerankModel?: { provider?: string; model?: string },
) => {
const shouldSetWeightDefaultValue = xorBy(selectedDatasets, originalDatasets, 'id').length > 0
const rerankModelIsValid = validRerankModel?.provider && validRerankModel?.model

const {
allHighQuality,
Expand Down Expand Up @@ -128,18 +129,10 @@ export const getMultipleRetrievalConfig = (
reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : true,
}

if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal)
result.reranking_mode = RerankingModeEnum.RerankingModel

if (allHighQuality && !inconsistentEmbeddingModel && reranking_mode === undefined && allInternal)
result.reranking_mode = RerankingModeEnum.WeightedScore

if (allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined) && allInternal && !weights) {
if (!isValidRerankModel)
result.reranking_mode = RerankingModeEnum.WeightedScore
else
result.reranking_mode = RerankingModeEnum.RerankingModel
if (!rerankModelIsValid)
result.reranking_model = undefined

const setDefaultWeights = () => {
result.weights = {
vector_setting: {
vector_weight: allHighQualityVectorSearch
Expand All @@ -160,31 +153,85 @@ export const getMultipleRetrievalConfig = (
}
}

if (shouldSetWeightDefaultValue && allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined || !isValidRerankModel) && allInternal && weights) {
if (!isValidRerankModel)
result.reranking_mode = RerankingModeEnum.WeightedScore
else
if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal) {
result.reranking_mode = RerankingModeEnum.RerankingModel

if (rerankModelIsValid) {
result.reranking_mode = RerankingModeEnum.RerankingModel
result.reranking_model = {
provider: validRerankModel?.provider || '',
model: validRerankModel?.model || '',
}
}
else {
result.reranking_model = undefined
}
}

result.weights = {
vector_setting: {
vector_weight: allHighQualityVectorSearch
? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.semantic
: allHighQualityFullTextSearch
? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.semantic
: DEFAULT_WEIGHTED_SCORE.other.semantic,
embedding_provider_name: selectedDatasets[0].embedding_model_provider,
embedding_model_name: selectedDatasets[0].embedding_model,
},
keyword_setting: {
keyword_weight: allHighQualityVectorSearch
? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.keyword
: allHighQualityFullTextSearch
? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.keyword
: DEFAULT_WEIGHTED_SCORE.other.keyword,
},
if (allHighQuality && !inconsistentEmbeddingModel && allInternal) {
if (!reranking_mode) {
if (validRerankModel?.provider && validRerankModel?.model) {
result.reranking_mode = RerankingModeEnum.RerankingModel
result.reranking_model = {
provider: validRerankModel.provider,
model: validRerankModel.model,
}
}
else {
result.reranking_mode = RerankingModeEnum.WeightedScore
setDefaultWeights()
}
}

if (reranking_mode === RerankingModeEnum.WeightedScore && !weights)
setDefaultWeights()

if (reranking_mode === RerankingModeEnum.WeightedScore && weights && shouldSetWeightDefaultValue) {
if (rerankModelIsValid) {
result.reranking_mode = RerankingModeEnum.RerankingModel
result.reranking_model = {
provider: validRerankModel.provider || '',
model: validRerankModel.model || '',
}
}
else {
setDefaultWeights()
}
}

if (reranking_mode === RerankingModeEnum.RerankingModel && !rerankModelIsValid && shouldSetWeightDefaultValue) {
result.reranking_mode = RerankingModeEnum.WeightedScore
setDefaultWeights()
}
}

return result
}

/**
 * Checks whether the multiple-retrieval settings pass the rerank-model requirement.
 *
 * The check passes (`true`) when:
 * - no retrieval config is supplied, or
 * - the reranking mode is not `RerankingModeEnum.RerankingModel`, or
 * - both `reranking_model.provider` and `reranking_model.model` are set.
 *
 * When the rerank-model mode is selected but the provider or model is missing,
 * the check only passes if `getSelectedDatasetsMode` reports `allEconomic` or
 * `allExternal` for the datasets and `reranking_enable` is falsy.
 *
 * @param datasets - Datasets passed to `getSelectedDatasetsMode` to derive the mode flags.
 * @param multipleRetrievalConfig - Retrieval settings to check; may be omitted.
 * @returns `true` if the check passes, otherwise `false`.
 */
export const checkoutRerankModelConfigedInRetrievalSettings = (
  datasets: DataSet[],
  multipleRetrievalConfig?: MultipleRetrievalConfig,
) => {
  // Nothing to validate without a config.
  if (!multipleRetrievalConfig)
    return true

  const datasetsMode = getSelectedDatasetsMode(datasets)
  const rerankModel = multipleRetrievalConfig.reranking_model
  const usesRerankModel = multipleRetrievalConfig.reranking_mode === RerankingModeEnum.RerankingModel
  const rerankModelConfigured = !!rerankModel?.provider && !!rerankModel?.model

  // Only the rerank-model mode with an incomplete model selection can fail.
  if (!usesRerankModel || rerankModelConfigured)
    return true

  // A missing rerank model is tolerated only when reranking is switched off
  // for datasets flagged as allEconomic or allExternal.
  const rerankCanBeSkipped = (datasetsMode.allEconomic || datasetsMode.allExternal)
    && !multipleRetrievalConfig.reranking_enable

  return !!rerankCanBeSkipped
}
Loading