diff --git a/src/prisma/generator/schema.py b/src/prisma/generator/schema.py index 0dd6a6b6a..88ceb6675 100644 --- a/src/prisma/generator/schema.py +++ b/src/prisma/generator/schema.py @@ -28,12 +28,12 @@ class PrismaType(BaseModel): subtypes: List['PrismaType'] = [] @classmethod - def from_subtypes(cls, subtypes: List['PrismaType'], **kwargs: Any) -> Union['PrismaUnion', 'PrismaAlias']: - """Return either a `PrismaUnion` or a `PrismaAlias` depending on the number of subtypes""" - if len(subtypes) > 1: - return PrismaUnion(subtypes=subtypes, **kwargs) + def from_variants(cls, variants: List['PrismaType'], **kwargs: Any) -> Union['PrismaUnion', 'PrismaAlias']: + """Return either a `PrismaUnion` or a `PrismaAlias` depending on the number of variants""" + if len(variants) > 1: + return PrismaUnion(variants=variants, **kwargs) - return PrismaAlias(subtypes=subtypes, **kwargs) + return PrismaAlias(subtypes=variants, **kwargs) class PrismaDict(PrismaType): @@ -44,7 +44,18 @@ class PrismaDict(PrismaType): class PrismaUnion(PrismaType): kind: Kind = Kind.union - subtypes: List[PrismaType] + variants: List[PrismaType] + + @root_validator(pre=True) + @classmethod + def add_subtypes(cls, values: Dict[str, Any]) -> Dict[str, Any]: + # add all variants as subtypes so that we don't have to special + # case rendering subtypes for unions + if 'variants' in values: + subtypes = values.get('subtypes', []) + subtypes.extend(values['variants']) + values['subtypes'] = subtypes + return values class PrismaEnum(PrismaType): @@ -94,7 +105,7 @@ class Config: def where_unique(self) -> PrismaType: info = self.info model = info.name - subtypes: List[PrismaType] = [ + variants: List[PrismaType] = [ PrismaDict( total=True, name=f'_{model}WhereUnique_{field.name}_Input', @@ -115,7 +126,7 @@ def where_unique(self) -> PrismaType: else: name = f'_{model}Compound{key.name}Key' - subtypes.append( + variants.append( PrismaDict( name=name, total=True, @@ -132,12 +143,12 @@ def where_unique(self) -> PrismaType: ) ) - return PrismaType.from_subtypes(subtypes, name=f'{model}WhereUniqueInput') + return PrismaType.from_variants(variants, name=f'{model}WhereUniqueInput') @cached_property def order_by(self) -> PrismaType: model = self.info.name - subtypes: List[PrismaType] = [ + variants: List[PrismaType] = [ PrismaDict( name=f'_{model}_{field.name}_OrderByInput', total=True, @@ -147,7 +158,7 @@ def order_by(self) -> PrismaType: ) for field in self.info.scalar_fields ] - return PrismaType.from_subtypes(subtypes, name=f'{model}OrderByInput') + return PrismaType.from_variants(variants, name=f'{model}OrderByInput') class ClientTypes(BaseModel): diff --git a/src/prisma/generator/templates/types.py.jinja b/src/prisma/generator/templates/types.py.jinja index cc9ac46a9..c3b8e9e1f 100644 --- a/src/prisma/generator/templates/types.py.jinja +++ b/src/prisma/generator/templates/types.py.jinja @@ -49,8 +49,8 @@ from .utils import _NoneType {{ type.name }} = {{ type.to }} {% elif type.kind == 'union' %} {{ type.name }} = Union[ - {% for subtype in type.subtypes %} - '{{ subtype.name }}', + {% for variant in type.variants %} + '{{ variant.name }}', {% endfor %} ] {% elif type.kind == 'typeddict' %}