7
7
from typing import Optional
8
8
9
9
from pydantic import BaseModel , ConfigDict , Field
10
- from sqlalchemy import or_
11
10
12
11
from constants import HIDDEN_VALUE
13
12
from core .entities .model_entities import ModelStatus , ModelWithProviderEntity , SimpleModelProviderEntity
@@ -180,37 +179,35 @@ def get_custom_credentials(self, obfuscated: bool = False) -> dict | None:
180
179
else [],
181
180
)
182
181
183
- def custom_credentials_validate (self , credentials : dict ) -> tuple [ Provider | None , dict ] :
182
+ def _get_custom_provider_credentials (self ) -> Provider | None :
184
183
"""
185
- Validate custom credentials.
186
- :param credentials: provider credentials
187
- :return:
184
+ Get custom provider credentials.
188
185
"""
189
186
# get provider
190
187
model_provider_id = ModelProviderID (self .provider .provider )
188
+ provider_names = [self .provider .provider ]
191
189
if model_provider_id .is_langgenius ():
192
- provider_record = (
193
- db .session .query (Provider )
194
- .filter (
195
- Provider .tenant_id == self .tenant_id ,
196
- Provider .provider_type == ProviderType .CUSTOM .value ,
197
- or_ (
198
- Provider .provider_name == model_provider_id .provider_name ,
199
- Provider .provider_name == self .provider .provider ,
200
- ),
201
- )
202
- .first ()
203
- )
204
- else :
205
- provider_record = (
206
- db .session .query (Provider )
207
- .filter (
208
- Provider .tenant_id == self .tenant_id ,
209
- Provider .provider_type == ProviderType .CUSTOM .value ,
210
- Provider .provider_name == self .provider .provider ,
211
- )
212
- .first ()
190
+ provider_names .append (model_provider_id .provider_name )
191
+
192
+ provider_record = (
193
+ db .session .query (Provider )
194
+ .filter (
195
+ Provider .tenant_id == self .tenant_id ,
196
+ Provider .provider_type == ProviderType .CUSTOM .value ,
197
+ Provider .provider_name .in_ (provider_names ),
213
198
)
199
+ .first ()
200
+ )
201
+
202
+ return provider_record
203
+
204
+ def custom_credentials_validate (self , credentials : dict ) -> tuple [Provider | None , dict ]:
205
+ """
206
+ Validate custom credentials.
207
+ :param credentials: provider credentials
208
+ :return:
209
+ """
210
+ provider_record = self ._get_custom_provider_credentials ()
214
211
215
212
# Get provider credential secret variables
216
213
provider_credential_secret_variables = self .extract_secret_variables (
@@ -291,18 +288,7 @@ def delete_custom_credentials(self) -> None:
291
288
:return:
292
289
"""
293
290
# get provider
294
- provider_record = (
295
- db .session .query (Provider )
296
- .filter (
297
- Provider .tenant_id == self .tenant_id ,
298
- or_ (
299
- Provider .provider_name == ModelProviderID (self .provider .provider ).plugin_name ,
300
- Provider .provider_name == self .provider .provider ,
301
- ),
302
- Provider .provider_type == ProviderType .CUSTOM .value ,
303
- )
304
- .first ()
305
- )
291
+ provider_record = self ._get_custom_provider_credentials ()
306
292
307
293
# delete provider
308
294
if provider_record :
@@ -349,29 +335,47 @@ def get_custom_model_credentials(
349
335
350
336
return None
351
337
352
- def custom_model_credentials_validate (
353
- self , model_type : ModelType , model : str , credentials : dict
354
- ) -> tuple [ProviderModel | None , dict ]:
338
+ def _get_custom_model_credentials (
339
+ self ,
340
+ model_type : ModelType ,
341
+ model : str ,
342
+ ) -> ProviderModel | None :
355
343
"""
356
- Validate custom model credentials.
357
-
358
- :param model_type: model type
359
- :param model: model name
360
- :param credentials: model credentials
361
- :return:
344
+ Get custom model credentials.
362
345
"""
363
346
# get provider model
347
+ model_provider_id = ModelProviderID (self .provider .provider )
348
+ provider_names = [self .provider .provider ]
349
+ if model_provider_id .is_langgenius ():
350
+ provider_names .append (model_provider_id .provider_name )
351
+
364
352
provider_model_record = (
365
353
db .session .query (ProviderModel )
366
354
.filter (
367
355
ProviderModel .tenant_id == self .tenant_id ,
368
- ProviderModel .provider_name == self . provider . provider ,
356
+ ProviderModel .provider_name . in_ ( provider_names ) ,
369
357
ProviderModel .model_name == model ,
370
358
ProviderModel .model_type == model_type .to_origin_model_type (),
371
359
)
372
360
.first ()
373
361
)
374
362
363
+ return provider_model_record
364
+
365
+ def custom_model_credentials_validate (
366
+ self , model_type : ModelType , model : str , credentials : dict
367
+ ) -> tuple [ProviderModel | None , dict ]:
368
+ """
369
+ Validate custom model credentials.
370
+
371
+ :param model_type: model type
372
+ :param model: model name
373
+ :param credentials: model credentials
374
+ :return:
375
+ """
376
+ # get provider model
377
+ provider_model_record = self ._get_custom_model_credentials (model_type , model )
378
+
375
379
# Get provider credential secret variables
376
380
provider_credential_secret_variables = self .extract_secret_variables (
377
381
self .provider .model_credential_schema .credential_form_schemas
@@ -451,16 +455,7 @@ def delete_custom_model_credentials(self, model_type: ModelType, model: str) ->
451
455
:return:
452
456
"""
453
457
# get provider model
454
- provider_model_record = (
455
- db .session .query (ProviderModel )
456
- .filter (
457
- ProviderModel .tenant_id == self .tenant_id ,
458
- ProviderModel .provider_name == self .provider .provider ,
459
- ProviderModel .model_name == model ,
460
- ProviderModel .model_type == model_type .to_origin_model_type (),
461
- )
462
- .first ()
463
- )
458
+ provider_model_record = self ._get_custom_model_credentials (model_type , model )
464
459
465
460
# delete provider model
466
461
if provider_model_record :
@@ -475,24 +470,35 @@ def delete_custom_model_credentials(self, model_type: ModelType, model: str) ->
475
470
476
471
provider_model_credentials_cache .delete ()
477
472
478
- def enable_model (self , model_type : ModelType , model : str ) -> ProviderModelSetting :
473
+ def _get_provider_model_setting (self , model_type : ModelType , model : str ) -> ProviderModelSetting | None :
479
474
"""
480
- Enable model.
481
- :param model_type: model type
482
- :param model: model name
483
- :return:
475
+ Get provider model setting.
484
476
"""
485
- model_setting = (
477
+ model_provider_id = ModelProviderID (self .provider .provider )
478
+ provider_names = [self .provider .provider ]
479
+ if model_provider_id .is_langgenius ():
480
+ provider_names .append (model_provider_id .provider_name )
481
+
482
+ return (
486
483
db .session .query (ProviderModelSetting )
487
484
.filter (
488
485
ProviderModelSetting .tenant_id == self .tenant_id ,
489
- ProviderModelSetting .provider_name == self . provider . provider ,
486
+ ProviderModelSetting .provider_name . in_ ( provider_names ) ,
490
487
ProviderModelSetting .model_type == model_type .to_origin_model_type (),
491
488
ProviderModelSetting .model_name == model ,
492
489
)
493
490
.first ()
494
491
)
495
492
493
+ def enable_model (self , model_type : ModelType , model : str ) -> ProviderModelSetting :
494
+ """
495
+ Enable model.
496
+ :param model_type: model type
497
+ :param model: model name
498
+ :return:
499
+ """
500
+ model_setting = self ._get_provider_model_setting (model_type , model )
501
+
496
502
if model_setting :
497
503
model_setting .enabled = True
498
504
model_setting .updated_at = datetime .datetime .now (datetime .UTC ).replace (tzinfo = None )
@@ -516,16 +522,7 @@ def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetti
516
522
:param model: model name
517
523
:return:
518
524
"""
519
- model_setting = (
520
- db .session .query (ProviderModelSetting )
521
- .filter (
522
- ProviderModelSetting .tenant_id == self .tenant_id ,
523
- ProviderModelSetting .provider_name == self .provider .provider ,
524
- ProviderModelSetting .model_type == model_type .to_origin_model_type (),
525
- ProviderModelSetting .model_name == model ,
526
- )
527
- .first ()
528
- )
525
+ model_setting = self ._get_provider_model_setting (model_type , model )
529
526
530
527
if model_setting :
531
528
model_setting .enabled = False
@@ -550,13 +547,24 @@ def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optio
550
547
:param model: model name
551
548
:return:
552
549
"""
550
+ return self ._get_provider_model_setting (model_type , model )
551
+
552
+ def _get_load_balancing_config (self , model_type : ModelType , model : str ) -> Optional [LoadBalancingModelConfig ]:
553
+ """
554
+ Get load balancing config.
555
+ """
556
+ model_provider_id = ModelProviderID (self .provider .provider )
557
+ provider_names = [self .provider .provider ]
558
+ if model_provider_id .is_langgenius ():
559
+ provider_names .append (model_provider_id .provider_name )
560
+
553
561
return (
554
- db .session .query (ProviderModelSetting )
562
+ db .session .query (LoadBalancingModelConfig )
555
563
.filter (
556
- ProviderModelSetting .tenant_id == self .tenant_id ,
557
- ProviderModelSetting .provider_name == self . provider . provider ,
558
- ProviderModelSetting .model_type == model_type .to_origin_model_type (),
559
- ProviderModelSetting .model_name == model ,
564
+ LoadBalancingModelConfig .tenant_id == self .tenant_id ,
565
+ LoadBalancingModelConfig .provider_name . in_ ( provider_names ) ,
566
+ LoadBalancingModelConfig .model_type == model_type .to_origin_model_type (),
567
+ LoadBalancingModelConfig .model_name == model ,
560
568
)
561
569
.first ()
562
570
)
@@ -568,11 +576,16 @@ def enable_model_load_balancing(self, model_type: ModelType, model: str) -> Prov
568
576
:param model: model name
569
577
:return:
570
578
"""
579
+ model_provider_id = ModelProviderID (self .provider .provider )
580
+ provider_names = [self .provider .provider ]
581
+ if model_provider_id .is_langgenius ():
582
+ provider_names .append (model_provider_id .provider_name )
583
+
571
584
load_balancing_config_count = (
572
585
db .session .query (LoadBalancingModelConfig )
573
586
.filter (
574
587
LoadBalancingModelConfig .tenant_id == self .tenant_id ,
575
- LoadBalancingModelConfig .provider_name == self . provider . provider ,
588
+ LoadBalancingModelConfig .provider_name . in_ ( provider_names ) ,
576
589
LoadBalancingModelConfig .model_type == model_type .to_origin_model_type (),
577
590
LoadBalancingModelConfig .model_name == model ,
578
591
)
@@ -582,16 +595,7 @@ def enable_model_load_balancing(self, model_type: ModelType, model: str) -> Prov
582
595
if load_balancing_config_count <= 1 :
583
596
raise ValueError ("Model load balancing configuration must be more than 1." )
584
597
585
- model_setting = (
586
- db .session .query (ProviderModelSetting )
587
- .filter (
588
- ProviderModelSetting .tenant_id == self .tenant_id ,
589
- ProviderModelSetting .provider_name == self .provider .provider ,
590
- ProviderModelSetting .model_type == model_type .to_origin_model_type (),
591
- ProviderModelSetting .model_name == model ,
592
- )
593
- .first ()
594
- )
598
+ model_setting = self ._get_provider_model_setting (model_type , model )
595
599
596
600
if model_setting :
597
601
model_setting .load_balancing_enabled = True
@@ -616,11 +620,16 @@ def disable_model_load_balancing(self, model_type: ModelType, model: str) -> Pro
616
620
:param model: model name
617
621
:return:
618
622
"""
623
+ model_provider_id = ModelProviderID (self .provider .provider )
624
+ provider_names = [self .provider .provider ]
625
+ if model_provider_id .is_langgenius ():
626
+ provider_names .append (model_provider_id .provider_name )
627
+
619
628
model_setting = (
620
629
db .session .query (ProviderModelSetting )
621
630
.filter (
622
631
ProviderModelSetting .tenant_id == self .tenant_id ,
623
- ProviderModelSetting .provider_name == self . provider . provider ,
632
+ ProviderModelSetting .provider_name . in_ ( provider_names ) ,
624
633
ProviderModelSetting .model_type == model_type .to_origin_model_type (),
625
634
ProviderModelSetting .model_name == model ,
626
635
)
@@ -677,11 +686,16 @@ def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
677
686
return
678
687
679
688
# get preferred provider
689
+ model_provider_id = ModelProviderID (self .provider .provider )
690
+ provider_names = [self .provider .provider ]
691
+ if model_provider_id .is_langgenius ():
692
+ provider_names .append (model_provider_id .provider_name )
693
+
680
694
preferred_model_provider = (
681
695
db .session .query (TenantPreferredModelProvider )
682
696
.filter (
683
697
TenantPreferredModelProvider .tenant_id == self .tenant_id ,
684
- TenantPreferredModelProvider .provider_name == self . provider . provider ,
698
+ TenantPreferredModelProvider .provider_name . in_ ( provider_names ) ,
685
699
)
686
700
.first ()
687
701
)
0 commit comments