Skip to content

Commit e4d1757

Browse files
authored
fix: retrieval setting validate (#10454)
1 parent 16b9665 commit e4d1757

File tree

8 files changed

+129
-48
lines changed

8 files changed

+129
-48
lines changed

web/app/components/app/configuration/dataset-config/index.tsx

+5-1
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,16 @@ const DatasetConfig: FC = () => {
4747

4848
const {
4949
currentModel: currentRerankModel,
50+
currentProvider: currentRerankProvider,
5051
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
5152

5253
const onRemove = (id: string) => {
5354
const filteredDataSets = dataSet.filter(item => item.id !== id)
5455
setDataSet(filteredDataSets)
55-
const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, !!currentRerankModel)
56+
const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, {
57+
provider: currentRerankProvider?.provider,
58+
model: currentRerankModel?.model,
59+
})
5660
setDatasetConfigs({
5761
...(datasetConfigs as any),
5862
...retrievalConfig,

web/app/components/app/configuration/dataset-config/params-config/config-content.tsx

+1-1
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ const ConfigContent: FC<Props> = ({
172172
return false
173173

174174
return datasetConfigs.reranking_enable
175-
}, [canManuallyToggleRerank, datasetConfigs.reranking_enable])
175+
}, [canManuallyToggleRerank, datasetConfigs.reranking_enable, isRerankDefaultModelValid])
176176

177177
const handleDisabledSwitchClick = useCallback(() => {
178178
if (!currentRerankModel && !showRerankModel)

web/app/components/app/configuration/dataset-config/params-config/index.tsx

+5-1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ const ParamsConfig = ({
4343
const {
4444
defaultModel: rerankDefaultModel,
4545
currentModel: isRerankDefaultModelValid,
46+
currentProvider: rerankDefaultProvider,
4647
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
4748

4849
const isValid = () => {
@@ -91,7 +92,10 @@ const ParamsConfig = ({
9192
reranking_mode: restConfigs.reranking_mode,
9293
weights: restConfigs.weights,
9394
reranking_enable: restConfigs.reranking_enable,
94-
}, selectedDatasets, selectedDatasets, !!isRerankDefaultModelValid)
95+
}, selectedDatasets, selectedDatasets, {
96+
provider: rerankDefaultProvider?.provider,
97+
model: isRerankDefaultModelValid?.model,
98+
})
9599

96100
setTempDataSetConfigs({
97101
...retrievalConfig,

web/app/components/app/configuration/index.tsx

+9-2
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ const Configuration: FC = () => {
226226
const [rerankSettingModalOpen, setRerankSettingModalOpen] = useState(false)
227227
const {
228228
currentModel: currentRerankModel,
229+
currentProvider: currentRerankProvider,
229230
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
230231
const handleSelect = (data: DataSet[]) => {
231232
if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) {
@@ -279,7 +280,10 @@ const Configuration: FC = () => {
279280
reranking_mode: restConfigs.reranking_mode,
280281
weights: restConfigs.weights,
281282
reranking_enable: restConfigs.reranking_enable,
282-
}, newDatasets, dataSets, !!currentRerankModel)
283+
}, newDatasets, dataSets, {
284+
provider: currentRerankProvider?.provider,
285+
model: currentRerankModel?.model,
286+
})
283287

284288
setDatasetConfigs({
285289
...retrievalConfig,
@@ -620,7 +624,10 @@ const Configuration: FC = () => {
620624

621625
syncToPublishedConfig(config)
622626
setPublishedConfig(config)
623-
const retrievalConfig = getMultipleRetrievalConfig(modelConfig.dataset_configs, datasets, datasets, !!currentRerankModel)
627+
const retrievalConfig = getMultipleRetrievalConfig(modelConfig.dataset_configs, datasets, datasets, {
628+
provider: currentRerankProvider?.provider,
629+
model: currentRerankModel?.model,
630+
})
624631
setDatasetConfigs({
625632
retrieval_model: RETRIEVE_TYPE.multiWay,
626633
...modelConfig.dataset_configs,

web/app/components/workflow/nodes/knowledge-retrieval/default.ts

+9-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { BlockEnum } from '../../types'
22
import type { NodeDefault } from '../../types'
33
import type { KnowledgeRetrievalNodeType } from './types'
4-
import { RerankingModeEnum } from '@/models/datasets'
4+
import { checkoutRerankModelConfigedInRetrievalSettings } from './utils'
55
import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants'
66
import { DATASET_DEFAULT } from '@/config'
77
import { RETRIEVE_TYPE } from '@/types/app'
@@ -36,12 +36,17 @@ const nodeDefault: NodeDefault<KnowledgeRetrievalNodeType> = {
3636
if (!errorMessages && (!payload.dataset_ids || payload.dataset_ids.length === 0))
3737
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.knowledgeRetrieval.knowledge`) })
3838

39-
if (!errorMessages && payload.retrieval_mode === RETRIEVE_TYPE.multiWay && payload.multiple_retrieval_config?.reranking_mode === RerankingModeEnum.RerankingModel && !payload.multiple_retrieval_config?.reranking_model?.provider && payload.multiple_retrieval_config?.reranking_enable)
40-
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) })
41-
4239
if (!errorMessages && payload.retrieval_mode === RETRIEVE_TYPE.oneWay && !payload.single_retrieval_config?.model?.provider)
4340
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t('common.modelProvider.systemReasoningModel.key') })
4441

42+
const { _datasets, multiple_retrieval_config, retrieval_mode } = payload
43+
if (retrieval_mode === RETRIEVE_TYPE.multiWay) {
44+
const checked = checkoutRerankModelConfigedInRetrievalSettings(_datasets || [], multiple_retrieval_config)
45+
46+
if (!errorMessages && !checked)
47+
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) })
48+
}
49+
4550
return {
4651
isValid: !errorMessages,
4752
errorMessage: errorMessages,

web/app/components/workflow/nodes/knowledge-retrieval/types.ts

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import type { CommonNodeType, ModelConfig, ValueSelector } from '@/app/components/workflow/types'
22
import type { RETRIEVE_TYPE } from '@/types/app'
33
import type {
4+
DataSet,
45
RerankingModeEnum,
56
} from '@/models/datasets'
67

@@ -35,4 +36,5 @@ export type KnowledgeRetrievalNodeType = CommonNodeType & {
3536
retrieval_mode: RETRIEVE_TYPE
3637
multiple_retrieval_config?: MultipleRetrievalConfig
3738
single_retrieval_config?: SingleRetrievalConfig
39+
_datasets?: DataSet[]
3840
}

web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts

+18-6
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
6767

6868
const {
6969
currentModel: currentRerankModel,
70+
currentProvider: currentRerankProvider,
7071
} = useCurrentProviderAndModel(
7172
rerankModelList,
7273
rerankDefaultModel
@@ -163,7 +164,10 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
163164
draft.retrieval_mode = newMode
164165
if (newMode === RETRIEVE_TYPE.multiWay) {
165166
const multipleRetrievalConfig = draft.multiple_retrieval_config
166-
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel)
167+
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, {
168+
provider: currentRerankProvider?.provider,
169+
model: currentRerankModel?.model,
170+
})
167171
}
168172
else {
169173
const hasSetModel = draft.single_retrieval_config?.model?.provider
@@ -180,14 +184,17 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
180184
}
181185
})
182186
setInputs(newInputs)
183-
}, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel])
187+
}, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider])
184188

185189
const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
186190
const newInputs = produce(inputs, (draft) => {
187-
draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel)
191+
draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, {
192+
provider: currentRerankProvider?.provider,
193+
model: currentRerankModel?.model,
194+
})
188195
})
189196
setInputs(newInputs)
190-
}, [inputs, setInputs, selectedDatasets, currentRerankModel])
197+
}, [inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider])
191198

192199
// datasets
193200
useEffect(() => {
@@ -200,6 +207,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
200207
}
201208
const newInputs = produce(inputs, (draft) => {
202209
draft.dataset_ids = datasetIds
210+
draft._datasets = selectedDatasets
203211
})
204212
setInputs(newInputs)
205213
})()
@@ -228,10 +236,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
228236
} = getSelectedDatasetsMode(newDatasets)
229237
const newInputs = produce(inputs, (draft) => {
230238
draft.dataset_ids = newDatasets.map(d => d.id)
239+
draft._datasets = newDatasets
231240

232241
if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
233242
const multipleRetrievalConfig = draft.multiple_retrieval_config
234-
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, !!currentRerankModel)
243+
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, {
244+
provider: currentRerankProvider?.provider,
245+
model: currentRerankModel?.model,
246+
})
235247
}
236248
})
237249
setInputs(newInputs)
@@ -243,7 +255,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
243255
|| allExternal
244256
)
245257
setRerankModelOpen(true)
246-
}, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel])
258+
}, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel, currentRerankProvider])
247259

248260
const filterVar = useCallback((varPayload: Var) => {
249261
return varPayload.type === VarType.string

web/app/components/workflow/nodes/knowledge-retrieval/utils.ts

+80-33
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,10 @@ export const getMultipleRetrievalConfig = (
9494
multipleRetrievalConfig: MultipleRetrievalConfig,
9595
selectedDatasets: DataSet[],
9696
originalDatasets: DataSet[],
97-
isValidRerankModel?: boolean,
97+
validRerankModel?: { provider?: string; model?: string },
9898
) => {
9999
const shouldSetWeightDefaultValue = xorBy(selectedDatasets, originalDatasets, 'id').length > 0
100+
const rerankModelIsValid = validRerankModel?.provider && validRerankModel?.model
100101

101102
const {
102103
allHighQuality,
@@ -128,18 +129,10 @@ export const getMultipleRetrievalConfig = (
128129
reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : true,
129130
}
130131

131-
if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal)
132-
result.reranking_mode = RerankingModeEnum.RerankingModel
133-
134-
if (allHighQuality && !inconsistentEmbeddingModel && reranking_mode === undefined && allInternal)
135-
result.reranking_mode = RerankingModeEnum.WeightedScore
136-
137-
if (allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined) && allInternal && !weights) {
138-
if (!isValidRerankModel)
139-
result.reranking_mode = RerankingModeEnum.WeightedScore
140-
else
141-
result.reranking_mode = RerankingModeEnum.RerankingModel
132+
if (!rerankModelIsValid)
133+
result.reranking_model = undefined
142134

135+
const setDefaultWeights = () => {
143136
result.weights = {
144137
vector_setting: {
145138
vector_weight: allHighQualityVectorSearch
@@ -160,31 +153,85 @@ export const getMultipleRetrievalConfig = (
160153
}
161154
}
162155

163-
if (shouldSetWeightDefaultValue && allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined || !isValidRerankModel) && allInternal && weights) {
164-
if (!isValidRerankModel)
165-
result.reranking_mode = RerankingModeEnum.WeightedScore
166-
else
156+
if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal) {
157+
result.reranking_mode = RerankingModeEnum.RerankingModel
158+
159+
if (rerankModelIsValid) {
167160
result.reranking_mode = RerankingModeEnum.RerankingModel
161+
result.reranking_model = {
162+
provider: validRerankModel?.provider || '',
163+
model: validRerankModel?.model || '',
164+
}
165+
}
166+
else {
167+
result.reranking_model = undefined
168+
}
169+
}
168170

169-
result.weights = {
170-
vector_setting: {
171-
vector_weight: allHighQualityVectorSearch
172-
? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.semantic
173-
: allHighQualityFullTextSearch
174-
? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.semantic
175-
: DEFAULT_WEIGHTED_SCORE.other.semantic,
176-
embedding_provider_name: selectedDatasets[0].embedding_model_provider,
177-
embedding_model_name: selectedDatasets[0].embedding_model,
178-
},
179-
keyword_setting: {
180-
keyword_weight: allHighQualityVectorSearch
181-
? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.keyword
182-
: allHighQualityFullTextSearch
183-
? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.keyword
184-
: DEFAULT_WEIGHTED_SCORE.other.keyword,
185-
},
171+
if (allHighQuality && !inconsistentEmbeddingModel && allInternal) {
172+
if (!reranking_mode) {
173+
if (validRerankModel?.provider && validRerankModel?.model) {
174+
result.reranking_mode = RerankingModeEnum.RerankingModel
175+
result.reranking_model = {
176+
provider: validRerankModel.provider,
177+
model: validRerankModel.model,
178+
}
179+
}
180+
else {
181+
result.reranking_mode = RerankingModeEnum.WeightedScore
182+
setDefaultWeights()
183+
}
184+
}
185+
186+
if (reranking_mode === RerankingModeEnum.WeightedScore && !weights)
187+
setDefaultWeights()
188+
189+
if (reranking_mode === RerankingModeEnum.WeightedScore && weights && shouldSetWeightDefaultValue) {
190+
if (rerankModelIsValid) {
191+
result.reranking_mode = RerankingModeEnum.RerankingModel
192+
result.reranking_model = {
193+
provider: validRerankModel.provider || '',
194+
model: validRerankModel.model || '',
195+
}
196+
}
197+
else {
198+
setDefaultWeights()
199+
}
200+
}
201+
202+
if (reranking_mode === RerankingModeEnum.RerankingModel && !rerankModelIsValid && shouldSetWeightDefaultValue) {
203+
result.reranking_mode = RerankingModeEnum.WeightedScore
204+
setDefaultWeights()
186205
}
187206
}
188207

189208
return result
190209
}
210+
211+
/**
 * Checks whether the multiple-retrieval settings pass the rerank-model requirement.
 *
 * Returns `true` when no config is supplied, when `reranking_mode` is not
 * `RerankingModeEnum.RerankingModel`, or when `reranking_model` has both a
 * `provider` and a `model`. When the mode is `RerankingModel` and either field is
 * missing, the settings are still accepted only if `getSelectedDatasetsMode`
 * reports the datasets as all economic or all external AND `reranking_enable`
 * is falsy; in every other case the function returns `false`.
 *
 * @param datasets - Datasets handed to `getSelectedDatasetsMode` to classify the selection.
 * @param multipleRetrievalConfig - Optional retrieval settings to validate.
 * @returns `true` if the settings pass the check, `false` if a rerank model is required but not configured.
 */
export const checkoutRerankModelConfigedInRetrievalSettings = (
212+
datasets: DataSet[],
213+
multipleRetrievalConfig?: MultipleRetrievalConfig,
214+
) => {
215+
if (!multipleRetrievalConfig)
216+
return true
217+
218+
const {
219+
allEconomic,
220+
allExternal,
221+
} = getSelectedDatasetsMode(datasets)
222+
223+
const {
224+
reranking_enable,
225+
reranking_mode,
226+
reranking_model,
227+
} = multipleRetrievalConfig
228+
229+
if (reranking_mode === RerankingModeEnum.RerankingModel && (!reranking_model?.provider || !reranking_model?.model)) {
230+
if ((allEconomic || allExternal) && !reranking_enable)
231+
return true
232+
233+
return false
234+
}
235+
236+
return true
237+
}

0 commit comments

Comments
 (0)