10
10
Protocol ,
11
11
Required ,
12
12
TypedDict ,
13
+ TypeGuard ,
13
14
final ,
14
15
override ,
15
16
runtime_checkable ,
31
32
HttpxRequestFiles ,
32
33
)
33
34
from ._utils import (
35
+ PropertyInfo ,
34
36
is_list ,
35
37
is_given ,
36
38
is_mapping ,
39
41
strip_not_given ,
40
42
extract_type_arg ,
41
43
is_annotated_type ,
44
+ strip_annotated_type ,
42
45
)
43
46
from ._compat import (
44
47
PYDANTIC_V2 ,
55
58
)
56
59
from ._constants import RAW_RESPONSE_HEADER
57
60
61
+ if TYPE_CHECKING :
62
+ from pydantic_core .core_schema import ModelField , ModelFieldsSchema
63
+
58
64
__all__ = ["BaseModel" , "GenericModel" ]
59
65
60
66
_T = TypeVar ("_T" )
@@ -268,14 +274,18 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object:
268
274
269
275
def is_basemodel (type_ : type ) -> bool :
270
276
"""Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
271
- origin = get_origin (type_ ) or type_
272
277
if is_union (type_ ):
273
278
for variant in get_args (type_ ):
274
279
if is_basemodel (variant ):
275
280
return True
276
281
277
282
return False
278
283
284
+ return is_basemodel_type (type_ )
285
+
286
+
287
+ def is_basemodel_type (type_ : type ) -> TypeGuard [type [BaseModel ] | type [GenericModel ]]:
288
+ origin = get_origin (type_ ) or type_
279
289
return issubclass (origin , BaseModel ) or issubclass (origin , GenericModel )
280
290
281
291
@@ -286,7 +296,10 @@ def construct_type(*, value: object, type_: type) -> object:
286
296
"""
287
297
# unwrap `Annotated[T, ...]` -> `T`
288
298
if is_annotated_type (type_ ):
299
+ meta = get_args (type_ )[1 :]
289
300
type_ = extract_type_arg (type_ , 0 )
301
+ else :
302
+ meta = tuple ()
290
303
291
304
# we need to use the origin class for any types that are subscripted generics
292
305
# e.g. Dict[str, object]
@@ -299,6 +312,28 @@ def construct_type(*, value: object, type_: type) -> object:
299
312
except Exception :
300
313
pass
301
314
315
+ # if the type is a discriminated union then we want to construct the right variant
316
+ # in the union, even if the data doesn't match exactly, otherwise we'd break code
317
+ # that relies on the constructed class types, e.g.
318
+ #
319
+ # class FooType:
320
+ # kind: Literal['foo']
321
+ # value: str
322
+ #
323
+ # class BarType:
324
+ # kind: Literal['bar']
325
+ # value: int
326
+ #
327
+ # without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
328
+ # we'd end up constructing `FooType` when it should be `BarType`.
329
+ discriminator = _build_discriminated_union_meta (union = type_ , meta_annotations = meta )
330
+ if discriminator and is_mapping (value ):
331
+ variant_value = value .get (discriminator .field_alias_from or discriminator .field_name )
332
+ if variant_value and isinstance (variant_value , str ):
333
+ variant_type = discriminator .mapping .get (variant_value )
334
+ if variant_type :
335
+ return construct_type (type_ = variant_type , value = value )
336
+
302
337
# if the data is not valid, use the first variant that doesn't fail while deserializing
303
338
for variant in args :
304
339
try :
@@ -356,6 +391,129 @@ def construct_type(*, value: object, type_: type) -> object:
356
391
return value
357
392
358
393
394
+ @runtime_checkable
395
+ class CachedDiscriminatorType (Protocol ):
396
+ __discriminator__ : DiscriminatorDetails
397
+
398
+
399
+ class DiscriminatorDetails :
400
+ field_name : str
401
+ """The name of the discriminator field in the variant class, e.g.
402
+
403
+ ```py
404
+ class Foo(BaseModel):
405
+ type: Literal['foo']
406
+ ```
407
+
408
+ Will result in field_name='type'
409
+ """
410
+
411
+ field_alias_from : str | None
412
+ """The name of the discriminator field in the API response, e.g.
413
+
414
+ ```py
415
+ class Foo(BaseModel):
416
+ type: Literal['foo'] = Field(alias='type_from_api')
417
+ ```
418
+
419
+ Will result in field_alias_from='type_from_api'
420
+ """
421
+
422
+ mapping : dict [str , type ]
423
+ """Mapping of discriminator value to variant type, e.g.
424
+
425
+ {'foo': FooVariant, 'bar': BarVariant}
426
+ """
427
+
428
+ def __init__ (
429
+ self ,
430
+ * ,
431
+ mapping : dict [str , type ],
432
+ discriminator_field : str ,
433
+ discriminator_alias : str | None ,
434
+ ) -> None :
435
+ self .mapping = mapping
436
+ self .field_name = discriminator_field
437
+ self .field_alias_from = discriminator_alias
438
+
439
+
440
+ def _build_discriminated_union_meta (* , union : type , meta_annotations : tuple [Any , ...]) -> DiscriminatorDetails | None :
441
+ if isinstance (union , CachedDiscriminatorType ):
442
+ return union .__discriminator__
443
+
444
+ discriminator_field_name : str | None = None
445
+
446
+ for annotation in meta_annotations :
447
+ if isinstance (annotation , PropertyInfo ) and annotation .discriminator is not None :
448
+ discriminator_field_name = annotation .discriminator
449
+ break
450
+
451
+ if not discriminator_field_name :
452
+ return None
453
+
454
+ mapping : dict [str , type ] = {}
455
+ discriminator_alias : str | None = None
456
+
457
+ for variant in get_args (union ):
458
+ variant = strip_annotated_type (variant )
459
+ if is_basemodel_type (variant ):
460
+ if PYDANTIC_V2 :
461
+ field = _extract_field_schema_pv2 (variant , discriminator_field_name )
462
+ if not field :
463
+ continue
464
+
465
+ # Note: if one variant defines an alias then they all should
466
+ discriminator_alias = field .get ("serialization_alias" )
467
+
468
+ field_schema = field ["schema" ]
469
+
470
+ if field_schema ["type" ] == "literal" :
471
+ for entry in field_schema ["expected" ]:
472
+ if isinstance (entry , str ):
473
+ mapping [entry ] = variant
474
+ else :
475
+ field_info = cast ("dict[str, FieldInfo]" , variant .__fields__ ).get (discriminator_field_name ) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
476
+ if not field_info :
477
+ continue
478
+
479
+ # Note: if one variant defines an alias then they all should
480
+ discriminator_alias = field_info .alias
481
+
482
+ if field_info .annotation and is_literal_type (field_info .annotation ):
483
+ for entry in get_args (field_info .annotation ):
484
+ if isinstance (entry , str ):
485
+ mapping [entry ] = variant
486
+
487
+ if not mapping :
488
+ return None
489
+
490
+ details = DiscriminatorDetails (
491
+ mapping = mapping ,
492
+ discriminator_field = discriminator_field_name ,
493
+ discriminator_alias = discriminator_alias ,
494
+ )
495
+ cast (CachedDiscriminatorType , union ).__discriminator__ = details
496
+ return details
497
+
498
+
499
+ def _extract_field_schema_pv2 (model : type [BaseModel ], field_name : str ) -> ModelField | None :
500
+ schema = model .__pydantic_core_schema__
501
+ if schema ["type" ] != "model" :
502
+ return None
503
+
504
+ fields_schema = schema ["schema" ]
505
+ if fields_schema ["type" ] != "model-fields" :
506
+ return None
507
+
508
+ fields_schema = cast ("ModelFieldsSchema" , fields_schema )
509
+
510
+ field = fields_schema ["fields" ].get (field_name )
511
+ if not field :
512
+ return None
513
+
514
+ return cast ("ModelField" , field ) # pyright: ignore[reportUnnecessaryCast]
515
+
516
+
359
517
def validate_type (* , type_ : type [_T ], value : object ) -> _T :
360
518
"""Strict validation that the given value matches the expected type"""
361
519
if inspect .isclass (type_ ) and issubclass (type_ , pydantic .BaseModel ):
0 commit comments