1
+ from __future__ import annotations
2
+
1
3
import copy
4
+ from typing import TYPE_CHECKING , Any , Generic , TypeVar , cast , overload
5
+
6
+ T = TypeVar ("T" )
7
+
8
+ if TYPE_CHECKING :
9
+ from collections .abc import Iterable , Iterator , Sequence
10
+
11
+ # The type aliases defined here are evaluated when the django-stubs mypy plugin
12
+ # loads this module, so they must be able to execute under the lowest supported
13
+ # Python VM:
14
+ # - typing.List, typing.Tuple become obsolete in Pyton 3.9
15
+ # - typing.Union becomes obsolete in Pyton 3.10
16
+ from typing import List , Tuple , Union
17
+
18
+ from django_stubs_ext import StrOrPromise
19
+
20
+ # The type argument 'T' to 'Choices' is the database representation type.
21
+ _Double = Tuple [T , StrOrPromise ]
22
+ _Triple = Tuple [T , str , StrOrPromise ]
23
+ _Group = Tuple [StrOrPromise , Sequence ["_Choice[T]" ]]
24
+ _Choice = Union [_Double [T ], _Triple [T ], _Group [T ]]
25
+ # Choices can only be given as a single string if 'T' is 'str'.
26
+ _GroupStr = Tuple [StrOrPromise , Sequence ["_ChoiceStr" ]]
27
+ _ChoiceStr = Union [str , _Double [str ], _Triple [str ], _GroupStr ]
28
+ # Note that we only accept lists and tuples in groups, not arbitrary sequences.
29
+ # However, annotating it as such causes many problems.
30
+
31
+ _DoubleRead = Union [_Double [T ], Tuple [StrOrPromise , Iterable ["_DoubleRead[T]" ]]]
32
+ _DoubleCollector = List [Union [_Double [T ], Tuple [StrOrPromise , "_DoubleCollector[T]" ]]]
33
+ _TripleCollector = List [Union [_Triple [T ], Tuple [StrOrPromise , "_TripleCollector[T]" ]]]
2
34
3
35
4
- class Choices :
36
+ class Choices ( Generic [ T ]) :
5
37
"""
6
38
A class to encapsulate handy functionality for lists of choices
7
39
for a Django model field.
@@ -41,36 +73,60 @@ class Choices:
41
73
42
74
"""
43
75
44
- def __init__ (self , * choices ):
76
+ @overload
77
+ def __init__ (self : Choices [str ], * choices : _ChoiceStr ):
78
+ ...
79
+
80
+ @overload
81
+ def __init__ (self , * choices : _Choice [T ]):
82
+ ...
83
+
84
+ def __init__ (self , * choices : _ChoiceStr | _Choice [T ]):
45
85
# list of choices expanded to triples - can include optgroups
46
- self ._triples = []
86
+ self ._triples : _TripleCollector [ T ] = []
47
87
# list of choices as (db, human-readable) - can include optgroups
48
- self ._doubles = []
88
+ self ._doubles : _DoubleCollector [ T ] = []
49
89
# dictionary mapping db representation to human-readable
50
- self ._display_map = {}
90
+ self ._display_map : dict [ T , StrOrPromise | list [ _Triple [ T ]]] = {}
51
91
# dictionary mapping Python identifier to db representation
52
- self ._identifier_map = {}
92
+ self ._identifier_map : dict [ str , T ] = {}
53
93
# set of db representations
54
- self ._db_values = set ()
94
+ self ._db_values : set [ T ] = set ()
55
95
56
96
self ._process (choices )
57
97
58
- def _store (self , triple , triple_collector , double_collector ):
98
+ def _store (
99
+ self ,
100
+ triple : tuple [T , str , StrOrPromise ],
101
+ triple_collector : _TripleCollector [T ],
102
+ double_collector : _DoubleCollector [T ]
103
+ ) -> None :
59
104
self ._identifier_map [triple [1 ]] = triple [0 ]
60
105
self ._display_map [triple [0 ]] = triple [2 ]
61
106
self ._db_values .add (triple [0 ])
62
107
triple_collector .append (triple )
63
108
double_collector .append ((triple [0 ], triple [2 ]))
64
109
65
- def _process (self , choices , triple_collector = None , double_collector = None ):
110
+ def _process (
111
+ self ,
112
+ choices : Iterable [_ChoiceStr | _Choice [T ]],
113
+ triple_collector : _TripleCollector [T ] | None = None ,
114
+ double_collector : _DoubleCollector [T ] | None = None
115
+ ) -> None :
66
116
if triple_collector is None :
67
117
triple_collector = self ._triples
68
118
if double_collector is None :
69
119
double_collector = self ._doubles
70
120
71
- store = lambda c : self ._store (c , triple_collector , double_collector )
121
+ def store (c : tuple [Any , str , StrOrPromise ]) -> None :
122
+ self ._store (c , triple_collector , double_collector )
72
123
73
124
for choice in choices :
125
+ # The type inference is not very accurate here:
126
+ # - we lied in the type aliases, stating groups contain an arbitrary Sequence
127
+ # rather than only list or tuple
128
+ # - there is no way to express that _ChoiceStr is only used when T=str
129
+ # - mypy 1.9.0 doesn't narrow types based on the value of len()
74
130
if isinstance (choice , (list , tuple )):
75
131
if len (choice ) == 3 :
76
132
store (choice )
@@ -79,13 +135,13 @@ def _process(self, choices, triple_collector=None, double_collector=None):
79
135
# option group
80
136
group_name = choice [0 ]
81
137
subchoices = choice [1 ]
82
- tc = []
138
+ tc : _TripleCollector [ T ] = []
83
139
triple_collector .append ((group_name , tc ))
84
- dc = []
140
+ dc : _DoubleCollector [ T ] = []
85
141
double_collector .append ((group_name , dc ))
86
142
self ._process (subchoices , tc , dc )
87
143
else :
88
- store ((choice [0 ], choice [0 ], choice [1 ]))
144
+ store ((choice [0 ], cast ( str , choice [0 ]), cast ( 'StrOrPromise' , choice [1 ]) ))
89
145
else :
90
146
raise ValueError (
91
147
"Choices can't take a list of length %s, only 2 or 3"
@@ -94,54 +150,74 @@ def _process(self, choices, triple_collector=None, double_collector=None):
94
150
else :
95
151
store ((choice , choice , choice ))
96
152
97
- def __len__ (self ):
153
+ def __len__ (self ) -> int :
98
154
return len (self ._doubles )
99
155
100
- def __iter__ (self ):
156
+ def __iter__ (self ) -> Iterator [ _DoubleRead [ T ]] :
101
157
return iter (self ._doubles )
102
158
103
- def __reversed__ (self ):
159
+ def __reversed__ (self ) -> Iterator [ _DoubleRead [ T ]] :
104
160
return reversed (self ._doubles )
105
161
106
- def __getattr__ (self , attname ) :
162
+ def __getattr__ (self , attname : str ) -> T :
107
163
try :
108
164
return self ._identifier_map [attname ]
109
165
except KeyError :
110
166
raise AttributeError (attname )
111
167
112
- def __getitem__ (self , key ) :
168
+ def __getitem__ (self , key : T ) -> StrOrPromise | Sequence [ _Triple [ T ]] :
113
169
return self ._display_map [key ]
114
170
115
- def __add__ (self , other ):
171
+ @overload
172
+ def __add__ (self : Choices [str ], other : Choices [str ] | Iterable [_ChoiceStr ]) -> Choices [str ]:
173
+ ...
174
+
175
+ @overload
176
+ def __add__ (self , other : Choices [T ] | Iterable [_Choice [T ]]) -> Choices [T ]:
177
+ ...
178
+
179
+ def __add__ (self , other : Choices [Any ] | Iterable [_ChoiceStr | _Choice [Any ]]) -> Choices [Any ]:
180
+ other_args : list [Any ]
116
181
if isinstance (other , self .__class__ ):
117
- other = other ._triples
182
+ other_args = other ._triples
118
183
else :
119
- other = list (other )
120
- return Choices (* (self ._triples + other ))
184
+ other_args = list (other )
185
+ return Choices (* (self ._triples + other_args ))
186
+
187
+ @overload
188
+ def __radd__ (self : Choices [str ], other : Iterable [_ChoiceStr ]) -> Choices [str ]:
189
+ ...
190
+
191
+ @overload
192
+ def __radd__ (self , other : Iterable [_Choice [T ]]) -> Choices [T ]:
193
+ ...
121
194
122
- def __radd__ (self , other ) :
195
+ def __radd__ (self , other : Iterable [ _ChoiceStr ] | Iterable [ _Choice [ T ]]) -> Choices [ Any ] :
123
196
# radd is never called for matching types, so we don't check here
124
- other = list (other )
125
- return Choices (* (other + self ._triples ))
197
+ other_args = list (other )
198
+ # The exact type of 'other' depends on our type argument 'T', which
199
+ # is expressed in the overloading, but lost within this method body.
200
+ return Choices (* (other_args + self ._triples )) # type: ignore[arg-type]
126
201
127
- def __eq__ (self , other ) :
202
+ def __eq__ (self , other : object ) -> bool :
128
203
if isinstance (other , self .__class__ ):
129
204
return self ._triples == other ._triples
130
205
return False
131
206
132
- def __repr__ (self ):
207
+ def __repr__ (self ) -> str :
133
208
return '{}({})' .format (
134
209
self .__class__ .__name__ ,
135
210
', ' .join ("%s" % repr (i ) for i in self ._triples )
136
211
)
137
212
138
- def __contains__ (self , item ) :
213
+ def __contains__ (self , item : T ) -> bool :
139
214
return item in self ._db_values
140
215
141
- def __deepcopy__ (self , memo ):
142
- return self .__class__ (* copy .deepcopy (self ._triples , memo ))
216
+ def __deepcopy__ (self , memo : dict [int , Any ] | None ) -> Choices [T ]:
217
+ args : list [Any ] = copy .deepcopy (self ._triples , memo )
218
+ return self .__class__ (* args )
143
219
144
- def subset (self , * new_identifiers ) :
220
+ def subset (self , * new_identifiers : str ) -> Choices [ T ] :
145
221
identifiers = set (self ._identifier_map .keys ())
146
222
147
223
if not identifiers .issuperset (new_identifiers ):
@@ -150,7 +226,8 @@ def subset(self, *new_identifiers):
150
226
identifiers .symmetric_difference (new_identifiers ),
151
227
)
152
228
153
- return self . __class__ ( * [
229
+ args : list [ Any ] = [
154
230
choice for choice in self ._triples
155
231
if choice [1 ] in new_identifiers
156
- ])
232
+ ]
233
+ return self .__class__ (* args )
0 commit comments