Skip to content

Commit 72fd873

Browse files
committed
✨ feat: support user config model
1 parent d865ca1 commit 72fd873

File tree

11 files changed

+149
-43
lines changed

11 files changed

+149
-43
lines changed

src/app/settings/llm/Ollama/index.tsx

-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ const OllamaProvider = memo(() => {
2222
}}
2323
provider={ModelProvider.Ollama}
2424
showApiKey={false}
25-
showCustomModelName
2625
showEndpoint
2726
title={
2827
<Ollama.Combine color={theme.isDarkMode ? theme.colorText : theme.colorPrimary} size={24} />

src/app/settings/llm/components/CustomModelList/index.tsx

+15-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ import { memo } from 'react';
55

66
import { filterEnabledModels } from '@/config/modelProviders';
77
import { useGlobalStore } from '@/store/global';
8-
import { modelConfigSelectors } from '@/store/global/selectors';
8+
import { modelConfigSelectors, modelProviderSelectors } from '@/store/global/selectors';
9+
import { GlobalLLMProviderKey } from '@/types/settings';
910

1011
import { OptionRender } from './Option';
1112

@@ -18,22 +19,31 @@ const popup = css`
1819
`;
1920

2021
interface CustomModelSelectProps {
22+
onChange?: (value: string[]) => void;
2123
placeholder?: string;
2224
provider: string;
25+
value?: string[];
2326
}
2427

25-
const CustomModelSelect = memo<CustomModelSelectProps>(({ provider, placeholder }) => {
28+
const CustomModelSelect = memo<CustomModelSelectProps>(({ provider, placeholder, onChange }) => {
2629
const providerCard = useGlobalStore(
27-
(s) => modelConfigSelectors.modelSelectList(s).find((s) => s.id === provider),
30+
(s) => modelProviderSelectors.providerModelList(s).find((s) => s.id === provider),
2831
isEqual,
2932
);
33+
const providerConfig = useGlobalStore((s) =>
34+
modelConfigSelectors.providerConfig(provider as GlobalLLMProviderKey)(s),
35+
);
36+
3037
const defaultEnableModel = providerCard ? filterEnabledModels(providerCard) : [];
3138

3239
return (
33-
<Select
40+
<Select<string[]>
3441
allowClear
3542
defaultValue={defaultEnableModel}
3643
mode="tags"
44+
onChange={(value) => {
45+
onChange?.(value.filter(Boolean));
46+
}}
3747
optionFilterProp="label"
3848
optionRender={({ label, value }) => (
3949
<OptionRender displayName={label as string} id={value as string} />
@@ -45,6 +55,7 @@ const CustomModelSelect = memo<CustomModelSelectProps>(({ provider, placeholder
4555
placeholder={placeholder}
4656
popupClassName={cx(popup)}
4757
popupMatchSelectWidth={false}
58+
value={providerConfig?.models.filter(Boolean)}
4859
/>
4960
);
5061
});

src/components/ModelSelect/index.tsx

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ interface ModelInfoTagsProps extends ChatModelCard {
6363
}
6464
export const ModelInfoTags = memo<ModelInfoTagsProps>(
6565
({ directionReverse, placement = 'right', ...model }) => {
66-
const { t } = useTranslation('common');
66+
const { t } = useTranslation('components');
6767
const { styles, cx } = useStyles();
6868

6969
return (

src/features/AgentSetting/AgentConfig/ModelSelect.tsx

+6-10
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,16 @@ interface ModelOption {
2525

2626
const ModelSelect = memo(() => {
2727
const [model, updateConfig] = useStore((s) => [s.config.model, s.setAgentConfig]);
28-
const select = useGlobalStore(modelConfigSelectors.modelSelectList, isEqual);
28+
const enabledList = useGlobalStore(modelConfigSelectors.enabledModelProviderList, isEqual);
2929
const { styles } = useStyles();
3030

31-
const enabledList = select.filter((s) => s.enabled);
32-
3331
const options = useMemo<SelectProps['options']>(() => {
3432
const getChatModels = (provider: ModelProviderCard) =>
35-
provider.chatModels
36-
.filter((c) => !c.hidden)
37-
.map((model) => ({
38-
label: <ModelItemRender {...model} />,
39-
provider: provider.id,
40-
value: model.id,
41-
}));
33+
provider.chatModels.map((model) => ({
34+
label: <ModelItemRender {...model} />,
35+
provider: provider.id,
36+
value: model.id,
37+
}));
4238

4339
if (enabledList.length === 1) {
4440
const provider = enabledList[0];

src/features/ModelSwitchPanel/index.tsx

+39-12
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
1+
import { Icon } from '@lobehub/ui';
12
import { Dropdown } from 'antd';
23
import { createStyles } from 'antd-style';
34
import isEqual from 'fast-deep-equal';
5+
import { LucideArrowRight } from 'lucide-react';
6+
import { useRouter } from 'next/navigation';
47
import { PropsWithChildren, memo, useMemo } from 'react';
8+
import { useTranslation } from 'react-i18next';
9+
import { Flexbox } from 'react-layout-kit';
510

611
import { ModelItemRender, ProviderItemRender } from '@/components/ModelSelect';
712
import { useGlobalStore } from '@/store/global';
813
import { modelConfigSelectors } from '@/store/global/selectors';
914
import { useSessionStore } from '@/store/session';
1015
import { agentSelectors } from '@/store/session/selectors';
1116
import { ModelProviderCard } from '@/types/llm';
17+
import { withBasePath } from '@/utils/basePath';
1218

1319
const useStyles = createStyles(({ css, prefixCls }) => ({
1420
menu: css`
@@ -32,30 +38,51 @@ const useStyles = createStyles(({ css, prefixCls }) => ({
3238
}));
3339

3440
const ModelSwitchPanel = memo<PropsWithChildren>(({ children }) => {
35-
const { styles } = useStyles();
41+
const { t } = useTranslation('components');
42+
const { styles, theme } = useStyles();
3643
const model = useSessionStore(agentSelectors.currentAgentModel);
3744
const updateAgentConfig = useSessionStore((s) => s.updateAgentConfig);
3845

39-
const select = useGlobalStore(modelConfigSelectors.modelSelectList, isEqual);
40-
const enabledList = select.filter((s) => s.enabled);
46+
const router = useRouter();
47+
const enabledList = useGlobalStore(modelConfigSelectors.enabledModelProviderList, isEqual);
4148

4249
const items = useMemo(() => {
43-
const getModelItems = (provider: ModelProviderCard) =>
44-
provider.chatModels
45-
.filter((c) => !c.hidden)
46-
.map((model) => ({
47-
key: model.id,
48-
label: <ModelItemRender {...model} />,
49-
onClick: () => {
50-
updateAgentConfig({ model: model.id, provider: provider.id });
50+
const getModelItems = (provider: ModelProviderCard) => {
51+
const items = provider.chatModels.map((model) => ({
52+
key: model.id,
53+
label: <ModelItemRender {...model} />,
54+
onClick: () => {
55+
updateAgentConfig({ model: model.id, provider: provider.id });
56+
},
57+
}));
58+
59+
// if there is empty items, add a placeholder guide
60+
if (items.length === 0)
61+
return [
62+
{
63+
key: 'empty',
64+
label: (
65+
<Flexbox gap={8} horizontal style={{ color: theme.colorTextTertiary }}>
66+
{t('ModelSwitchPanel.emptyModel')}
67+
<Icon icon={LucideArrowRight} />
68+
</Flexbox>
69+
),
70+
onClick: () => {
71+
router.push(withBasePath('/settings/llm'));
72+
},
5173
},
52-
}));
74+
];
75+
76+
return items;
77+
};
5378

79+
// If there is only one provider, just remove the group, show model directly
5480
if (enabledList.length === 1) {
5581
const provider = enabledList[0];
5682
return getModelItems(provider);
5783
}
5884

85+
// otherwise show with provider group
5986
return enabledList.map((provider) => ({
6087
children: getModelItems(provider),
6188
key: provider.id,

src/locales/default/common.ts

-11
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,7 @@
11
export default {
2-
ModelSelect: {
3-
featureTag: {
4-
custom: '自定义模型,默认设定同时支持函数调用与视觉识别,请根据实际情况验证上述能力的可用性',
5-
file: '该模型支持上传文件读取与识别',
6-
functionCall: '该模型支持函数调用(Function Call)',
7-
tokens: '该模型单个会话最多支持 {{tokens}} Tokens',
8-
vision: '该模型支持视觉识别',
9-
},
10-
},
112
about: '关于',
123
advanceSettings: '高级设置',
13-
144
appInitializing: 'LobeChat 启动中,请耐心等待...',
15-
165
autoGenerate: '自动补全',
176
autoGenerateTooltip: '基于提示词自动补全助手描述',
187
cancel: '取消',

src/locales/default/components.ts

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
export default {
2+
ModelSelect: {
3+
featureTag: {
4+
custom: '自定义模型,默认设定同时支持函数调用与视觉识别,请根据实际情况验证上述能力的可用性',
5+
file: '该模型支持上传文件读取与识别',
6+
functionCall: '该模型支持函数调用(Function Call)',
7+
tokens: '该模型单个会话最多支持 {{tokens}} Tokens',
8+
vision: '该模型支持视觉识别',
9+
},
10+
},
11+
ModelSwitchPanel: {
12+
emptyModel: '没有启用的模型,请前往设置开启',
13+
provider: '提供商',
14+
},
15+
};

src/locales/default/index.ts

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import tool from '../default/tool';
22
import chat from './chat';
33
import common from './common';
4+
import components from './components';
45
import error from './error';
56
import market from './market';
67
import migration from './migration';
@@ -11,6 +12,7 @@ import welcome from './welcome';
1112
const resources = {
1213
chat,
1314
common,
15+
components,
1416
error,
1517
market,
1618
migration,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import { describe, expect, it } from 'vitest';
2+
3+
import { DEFAULT_SETTINGS } from '@/const/settings';
4+
import { modelProviderSelectors } from '@/store/global/slices/settings/selectors/modelProvider';
5+
import { agentSelectors } from '@/store/session/slices/agent';
6+
import { merge } from '@/utils/merge';
7+
8+
import { GlobalStore, useGlobalStore } from '../../../store';
9+
import { GlobalSettingsState, initialSettingsState } from '../initialState';
10+
import { modelConfigSelectors } from './modelConfig';
11+
12+
describe('modelConfigSelectors', () => {
13+
describe('modelSelectList', () => {
14+
it('visible', () => {
15+
const s = merge(initialSettingsState, {
16+
settings: {
17+
languageModel: {
18+
ollama: {
19+
models: ['llava'],
20+
},
21+
},
22+
},
23+
} as GlobalSettingsState) as unknown as GlobalStore;
24+
25+
const ollamaList = modelConfigSelectors.modelSelectList(s).find((r) => r.id === 'ollama');
26+
27+
expect(ollamaList?.chatModels).toEqual([]);
28+
});
29+
});
30+
});

src/store/global/slices/settings/selectors/modelConfig.ts

+40-3
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,30 @@
11
import { ModelProviderCard } from '@/types/llm';
2-
import { GlobalLLMProviderKey } from '@/types/settings';
2+
import { GeneralModelProviderConfig, GlobalLLMProviderKey } from '@/types/settings';
33

44
import { GlobalStore } from '../../../store';
55
import { modelProviderSelectors } from './modelProvider';
66
import { currentSettings } from './settings';
77

88
const modelProvider = (s: GlobalStore) => currentSettings(s).languageModel;
9-
const providerEnabled = (provider: GlobalLLMProviderKey) => (s: GlobalStore) =>
10-
currentSettings(s).languageModel[provider]?.enabled || false;
9+
10+
const providerConfig = (provider: string) => (s: GlobalStore) =>
11+
currentSettings(s).languageModel[provider as GlobalLLMProviderKey] as
12+
| GeneralModelProviderConfig
13+
| undefined;
14+
15+
const providerEnabled = (provider: GlobalLLMProviderKey) => (s: GlobalStore) => {
16+
// TODO: we need to migrate the 'openAI' key to 'openai'
17+
// @ts-ignore
18+
if (provider === 'openai') return true;
19+
20+
return currentSettings(s).languageModel[provider]?.enabled || false;
21+
};
22+
23+
const providerEnableModels =
24+
(provider: string) =>
25+
(s: GlobalStore): string[] | undefined => {
26+
return providerConfig(provider)(s)?.models;
27+
};
1128

1229
const openAIConfig = (s: GlobalStore) => modelProvider(s).openAI;
1330

@@ -67,14 +84,34 @@ const zerooneAPIKey = (s: GlobalStore) => modelProvider(s).zeroone.apiKey;
6784
const modelSelectList = (s: GlobalStore): ModelProviderCard[] => {
6885
return modelProviderSelectors.providerModelList(s).map((list) => ({
6986
...list,
87+
chatModels: list.chatModels.map((model) => {
88+
const models = providerEnableModels(list.id)(s);
89+
90+
if (!models) return model;
91+
92+
return {
93+
...model,
94+
hidden: !models?.some((m) => m === model.id),
95+
};
96+
}),
7097
enabled: providerEnabled(list.id as any)(s),
7198
}));
7299
};
73100

101+
const enabledModelProviderList = (s: GlobalStore): ModelProviderCard[] =>
102+
modelSelectList(s)
103+
.filter((s) => s.enabled)
104+
.map((provider) => ({
105+
...provider,
106+
chatModels: provider.chatModels.filter((model) => !model.hidden),
107+
}));
108+
74109
/* eslint-disable sort-keys-fix/sort-keys-fix, */
75110
export const modelConfigSelectors = {
76111
providerEnabled,
112+
providerConfig,
77113
modelSelectList,
114+
enabledModelProviderList,
78115

79116
// OpenAI
80117
openAIConfig,

src/types/settings/modelProvider.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
export type CustomModels = { displayName: string; id: string }[];
22

3-
interface GeneralModelProviderConfig {
3+
export interface GeneralModelProviderConfig {
44
apiKey?: string;
55
enabled: boolean;
66
endpoint?: string;

0 commit comments

Comments
 (0)