Skip to content
This repository was archived by the owner on Apr 15, 2025. It is now read-only.

Commit e489b63

Browse files
chore(internal): restructure PrismaUnion representation (#926)
We shouldn't be overloading `subtypes` for unions as we may need to distinguish between `subtypes` and `variants`, e.g. each variant has the same shared parent class
1 parent 3e615a8 commit e489b63

File tree

2 files changed

+24
-13
lines changed

2 files changed

+24
-13
lines changed

src/prisma/generator/schema.py

+22-11
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@ class PrismaType(BaseModel):
2828
subtypes: List['PrismaType'] = []
2929

3030
@classmethod
31-
def from_subtypes(cls, subtypes: List['PrismaType'], **kwargs: Any) -> Union['PrismaUnion', 'PrismaAlias']:
32-
"""Return either a `PrismaUnion` or a `PrismaAlias` depending on the number of subtypes"""
33-
if len(subtypes) > 1:
34-
return PrismaUnion(subtypes=subtypes, **kwargs)
31+
def from_variants(cls, variants: List['PrismaType'], **kwargs: Any) -> Union['PrismaUnion', 'PrismaAlias']:
32+
"""Return either a `PrismaUnion` or a `PrismaAlias` depending on the number of variants"""
33+
if len(variants) > 1:
34+
return PrismaUnion(variants=variants, **kwargs)
3535

36-
return PrismaAlias(subtypes=subtypes, **kwargs)
36+
return PrismaAlias(subtypes=variants, **kwargs)
3737

3838

3939
class PrismaDict(PrismaType):
@@ -44,7 +44,18 @@ class PrismaDict(PrismaType):
4444

4545
class PrismaUnion(PrismaType):
4646
kind: Kind = Kind.union
47-
subtypes: List[PrismaType]
47+
variants: List[PrismaType]
48+
49+
@root_validator(pre=True)
50+
@classmethod
51+
def add_subtypes(cls, values: Dict[str, Any]) -> Dict[str, Any]:
52+
# add all variants as subtypes so that we don't have to special
53+
# case rendering subtypes for unions
54+
if 'variants' in values:
55+
subtypes = values.get('subtypes', [])
56+
subtypes.extend(values['variants'])
57+
values['subtypes'] = subtypes
58+
return values
4859

4960

5061
class PrismaEnum(PrismaType):
@@ -94,7 +105,7 @@ class Config:
94105
def where_unique(self) -> PrismaType:
95106
info = self.info
96107
model = info.name
97-
subtypes: List[PrismaType] = [
108+
variants: List[PrismaType] = [
98109
PrismaDict(
99110
total=True,
100111
name=f'_{model}WhereUnique_{field.name}_Input',
@@ -115,7 +126,7 @@ def where_unique(self) -> PrismaType:
115126
else:
116127
name = f'_{model}Compound{key.name}Key'
117128

118-
subtypes.append(
129+
variants.append(
119130
PrismaDict(
120131
name=name,
121132
total=True,
@@ -132,12 +143,12 @@ def where_unique(self) -> PrismaType:
132143
)
133144
)
134145

135-
return PrismaType.from_subtypes(subtypes, name=f'{model}WhereUniqueInput')
146+
return PrismaType.from_variants(variants, name=f'{model}WhereUniqueInput')
136147

137148
@cached_property
138149
def order_by(self) -> PrismaType:
139150
model = self.info.name
140-
subtypes: List[PrismaType] = [
151+
variants: List[PrismaType] = [
141152
PrismaDict(
142153
name=f'_{model}_{field.name}_OrderByInput',
143154
total=True,
@@ -147,7 +158,7 @@ def order_by(self) -> PrismaType:
147158
)
148159
for field in self.info.scalar_fields
149160
]
150-
return PrismaType.from_subtypes(subtypes, name=f'{model}OrderByInput')
161+
return PrismaType.from_variants(variants, name=f'{model}OrderByInput')
151162

152163

153164
class ClientTypes(BaseModel):

src/prisma/generator/templates/types.py.jinja

+2-2
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ from .utils import _NoneType
4949
{{ type.name }} = {{ type.to }}
5050
{% elif type.kind == 'union' %}
5151
{{ type.name }} = Union[
52-
{% for subtype in type.subtypes %}
53-
'{{ subtype.name }}',
52+
{% for variant in type.variants %}
53+
'{{ variant.name }}',
5454
{% endfor %}
5555
]
5656
{% elif type.kind == 'typeddict' %}

0 commit comments

Comments
 (0)