Skip to content

Commit 37b3038

Browse files
stainless-app[bot]stainless-bot
authored andcommitted
feat: OpenAPI spec update via Stainless API (#115)
1 parent a93e4df commit 37b3038

File tree

3 files changed

+343
-2
lines changed

3 files changed

+343
-2
lines changed

src/cloudflare/_models.py

+159-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Protocol,
1111
Required,
1212
TypedDict,
13+
TypeGuard,
1314
final,
1415
override,
1516
runtime_checkable,
@@ -31,6 +32,7 @@
3132
HttpxRequestFiles,
3233
)
3334
from ._utils import (
35+
PropertyInfo,
3436
is_list,
3537
is_given,
3638
is_mapping,
@@ -39,6 +41,7 @@
3941
strip_not_given,
4042
extract_type_arg,
4143
is_annotated_type,
44+
strip_annotated_type,
4245
)
4346
from ._compat import (
4447
PYDANTIC_V2,
@@ -55,6 +58,9 @@
5558
)
5659
from ._constants import RAW_RESPONSE_HEADER
5760

61+
if TYPE_CHECKING:
62+
from pydantic_core.core_schema import ModelField, ModelFieldsSchema
63+
5864
__all__ = ["BaseModel", "GenericModel"]
5965

6066
_T = TypeVar("_T")
@@ -268,14 +274,18 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object:
268274

269275
def is_basemodel(type_: type) -> bool:
270276
"""Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
271-
origin = get_origin(type_) or type_
272277
if is_union(type_):
273278
for variant in get_args(type_):
274279
if is_basemodel(variant):
275280
return True
276281

277282
return False
278283

284+
return is_basemodel_type(type_)
285+
286+
287+
def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]:
288+
origin = get_origin(type_) or type_
279289
return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)
280290

281291

@@ -286,7 +296,10 @@ def construct_type(*, value: object, type_: type) -> object:
286296
"""
287297
# unwrap `Annotated[T, ...]` -> `T`
288298
if is_annotated_type(type_):
299+
meta = get_args(type_)[1:]
289300
type_ = extract_type_arg(type_, 0)
301+
else:
302+
meta = tuple()
290303

291304
# we need to use the origin class for any types that are subscripted generics
292305
# e.g. Dict[str, object]
@@ -299,6 +312,28 @@ def construct_type(*, value: object, type_: type) -> object:
299312
except Exception:
300313
pass
301314

315+
# if the type is a discriminated union then we want to construct the right variant
316+
# in the union, even if the data doesn't match exactly, otherwise we'd break code
317+
# that relies on the constructed class types, e.g.
318+
#
319+
# class FooType:
320+
# kind: Literal['foo']
321+
# value: str
322+
#
323+
# class BarType:
324+
# kind: Literal['bar']
325+
# value: int
326+
#
327+
# without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
328+
# we'd end up constructing `FooType` when it should be `BarType`.
329+
discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
330+
if discriminator and is_mapping(value):
331+
variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
332+
if variant_value and isinstance(variant_value, str):
333+
variant_type = discriminator.mapping.get(variant_value)
334+
if variant_type:
335+
return construct_type(type_=variant_type, value=value)
336+
302337
# if the data is not valid, use the first variant that doesn't fail while deserializing
303338
for variant in args:
304339
try:
@@ -356,6 +391,129 @@ def construct_type(*, value: object, type_: type) -> object:
356391
return value
357392

358393

394+
@runtime_checkable
395+
class CachedDiscriminatorType(Protocol):
396+
__discriminator__: DiscriminatorDetails
397+
398+
399+
class DiscriminatorDetails:
400+
field_name: str
401+
"""The name of the discriminator field in the variant class, e.g.
402+
403+
```py
404+
class Foo(BaseModel):
405+
type: Literal['foo']
406+
```
407+
408+
Will result in field_name='type'
409+
"""
410+
411+
field_alias_from: str | None
412+
"""The name of the discriminator field in the API response, e.g.
413+
414+
```py
415+
class Foo(BaseModel):
416+
type: Literal['foo'] = Field(alias='type_from_api')
417+
```
418+
419+
Will result in field_alias_from='type_from_api'
420+
"""
421+
422+
mapping: dict[str, type]
423+
"""Mapping of discriminator value to variant type, e.g.
424+
425+
{'foo': FooVariant, 'bar': BarVariant}
426+
"""
427+
428+
def __init__(
429+
self,
430+
*,
431+
mapping: dict[str, type],
432+
discriminator_field: str,
433+
discriminator_alias: str | None,
434+
) -> None:
435+
self.mapping = mapping
436+
self.field_name = discriminator_field
437+
self.field_alias_from = discriminator_alias
438+
439+
440+
def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
441+
if isinstance(union, CachedDiscriminatorType):
442+
return union.__discriminator__
443+
444+
discriminator_field_name: str | None = None
445+
446+
for annotation in meta_annotations:
447+
if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None:
448+
discriminator_field_name = annotation.discriminator
449+
break
450+
451+
if not discriminator_field_name:
452+
return None
453+
454+
mapping: dict[str, type] = {}
455+
discriminator_alias: str | None = None
456+
457+
for variant in get_args(union):
458+
variant = strip_annotated_type(variant)
459+
if is_basemodel_type(variant):
460+
if PYDANTIC_V2:
461+
field = _extract_field_schema_pv2(variant, discriminator_field_name)
462+
if not field:
463+
continue
464+
465+
# Note: if one variant defines an alias then they all should
466+
discriminator_alias = field.get("serialization_alias")
467+
468+
field_schema = field["schema"]
469+
470+
if field_schema["type"] == "literal":
471+
for entry in field_schema["expected"]:
472+
if isinstance(entry, str):
473+
mapping[entry] = variant
474+
else:
475+
field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
476+
if not field_info:
477+
continue
478+
479+
# Note: if one variant defines an alias then they all should
480+
discriminator_alias = field_info.alias
481+
482+
if field_info.annotation and is_literal_type(field_info.annotation):
483+
for entry in get_args(field_info.annotation):
484+
if isinstance(entry, str):
485+
mapping[entry] = variant
486+
487+
if not mapping:
488+
return None
489+
490+
details = DiscriminatorDetails(
491+
mapping=mapping,
492+
discriminator_field=discriminator_field_name,
493+
discriminator_alias=discriminator_alias,
494+
)
495+
cast(CachedDiscriminatorType, union).__discriminator__ = details
496+
return details
497+
498+
499+
def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
500+
schema = model.__pydantic_core_schema__
501+
if schema["type"] != "model":
502+
return None
503+
504+
fields_schema = schema["schema"]
505+
if fields_schema["type"] != "model-fields":
506+
return None
507+
508+
fields_schema = cast("ModelFieldsSchema", fields_schema)
509+
510+
field = fields_schema["fields"].get(field_name)
511+
if not field:
512+
return None
513+
514+
return cast("ModelField", field) # pyright: ignore[reportUnnecessaryCast]
515+
516+
359517
def validate_type(*, type_: type[_T], value: object) -> _T:
360518
"""Strict validation that the given value matches the expected type"""
361519
if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):

src/cloudflare/_utils/_transform.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -51,21 +51,24 @@ class MyParams(TypedDict):
5151
alias: str | None
5252
format: PropertyFormat | None
5353
format_template: str | None
54+
discriminator: str | None
5455

5556
def __init__(
5657
self,
5758
*,
5859
alias: str | None = None,
5960
format: PropertyFormat | None = None,
6061
format_template: str | None = None,
62+
discriminator: str | None = None,
6163
) -> None:
6264
self.alias = alias
6365
self.format = format
6466
self.format_template = format_template
67+
self.discriminator = discriminator
6568

6669
@override
6770
def __repr__(self) -> str:
68-
return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}')"
71+
return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')"
6972

7073

7174
def maybe_transform(

0 commit comments

Comments
 (0)