Skip to content

Commit 731ed80

Browse files
authored
Merge pull request #603 from ProtixIT/type-annotations
Add type annotations
2 parents 95664ef + 9dc2247 commit 731ed80

30 files changed

+1145
-645
lines changed

.coveragerc

+6
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,8 @@
11
[run]
22
include = model_utils/*.py
3+
4+
[report]
5+
exclude_also =
6+
# Exclusive to mypy:
7+
if TYPE_CHECKING:$
8+
\.\.\.$

model_utils/choices.py

+110-33
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,39 @@
1+
from __future__ import annotations
2+
13
import copy
4+
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload
5+
6+
T = TypeVar("T")
7+
8+
if TYPE_CHECKING:
9+
from collections.abc import Iterable, Iterator, Sequence
10+
11+
# The type aliases defined here are evaluated when the django-stubs mypy plugin
12+
# loads this module, so they must be able to execute under the lowest supported
13+
# Python VM:
14+
# - typing.List, typing.Tuple become obsolete in Pyton 3.9
15+
# - typing.Union becomes obsolete in Pyton 3.10
16+
from typing import List, Tuple, Union
17+
18+
from django_stubs_ext import StrOrPromise
19+
20+
# The type argument 'T' to 'Choices' is the database representation type.
21+
_Double = Tuple[T, StrOrPromise]
22+
_Triple = Tuple[T, str, StrOrPromise]
23+
_Group = Tuple[StrOrPromise, Sequence["_Choice[T]"]]
24+
_Choice = Union[_Double[T], _Triple[T], _Group[T]]
25+
# Choices can only be given as a single string if 'T' is 'str'.
26+
_GroupStr = Tuple[StrOrPromise, Sequence["_ChoiceStr"]]
27+
_ChoiceStr = Union[str, _Double[str], _Triple[str], _GroupStr]
28+
# Note that we only accept lists and tuples in groups, not arbitrary sequences.
29+
# However, annotating it as such causes many problems.
30+
31+
_DoubleRead = Union[_Double[T], Tuple[StrOrPromise, Iterable["_DoubleRead[T]"]]]
32+
_DoubleCollector = List[Union[_Double[T], Tuple[StrOrPromise, "_DoubleCollector[T]"]]]
33+
_TripleCollector = List[Union[_Triple[T], Tuple[StrOrPromise, "_TripleCollector[T]"]]]
234

335

4-
class Choices:
36+
class Choices(Generic[T]):
537
"""
638
A class to encapsulate handy functionality for lists of choices
739
for a Django model field.
@@ -41,36 +73,60 @@ class Choices:
4173
4274
"""
4375

44-
def __init__(self, *choices):
76+
@overload
77+
def __init__(self: Choices[str], *choices: _ChoiceStr):
78+
...
79+
80+
@overload
81+
def __init__(self, *choices: _Choice[T]):
82+
...
83+
84+
def __init__(self, *choices: _ChoiceStr | _Choice[T]):
4585
# list of choices expanded to triples - can include optgroups
46-
self._triples = []
86+
self._triples: _TripleCollector[T] = []
4787
# list of choices as (db, human-readable) - can include optgroups
48-
self._doubles = []
88+
self._doubles: _DoubleCollector[T] = []
4989
# dictionary mapping db representation to human-readable
50-
self._display_map = {}
90+
self._display_map: dict[T, StrOrPromise | list[_Triple[T]]] = {}
5191
# dictionary mapping Python identifier to db representation
52-
self._identifier_map = {}
92+
self._identifier_map: dict[str, T] = {}
5393
# set of db representations
54-
self._db_values = set()
94+
self._db_values: set[T] = set()
5595

5696
self._process(choices)
5797

58-
def _store(self, triple, triple_collector, double_collector):
98+
def _store(
99+
self,
100+
triple: tuple[T, str, StrOrPromise],
101+
triple_collector: _TripleCollector[T],
102+
double_collector: _DoubleCollector[T]
103+
) -> None:
59104
self._identifier_map[triple[1]] = triple[0]
60105
self._display_map[triple[0]] = triple[2]
61106
self._db_values.add(triple[0])
62107
triple_collector.append(triple)
63108
double_collector.append((triple[0], triple[2]))
64109

65-
def _process(self, choices, triple_collector=None, double_collector=None):
110+
def _process(
111+
self,
112+
choices: Iterable[_ChoiceStr | _Choice[T]],
113+
triple_collector: _TripleCollector[T] | None = None,
114+
double_collector: _DoubleCollector[T] | None = None
115+
) -> None:
66116
if triple_collector is None:
67117
triple_collector = self._triples
68118
if double_collector is None:
69119
double_collector = self._doubles
70120

71-
store = lambda c: self._store(c, triple_collector, double_collector)
121+
def store(c: tuple[Any, str, StrOrPromise]) -> None:
122+
self._store(c, triple_collector, double_collector)
72123

73124
for choice in choices:
125+
# The type inference is not very accurate here:
126+
# - we lied in the type aliases, stating groups contain an arbitrary Sequence
127+
# rather than only list or tuple
128+
# - there is no way to express that _ChoiceStr is only used when T=str
129+
# - mypy 1.9.0 doesn't narrow types based on the value of len()
74130
if isinstance(choice, (list, tuple)):
75131
if len(choice) == 3:
76132
store(choice)
@@ -79,13 +135,13 @@ def _process(self, choices, triple_collector=None, double_collector=None):
79135
# option group
80136
group_name = choice[0]
81137
subchoices = choice[1]
82-
tc = []
138+
tc: _TripleCollector[T] = []
83139
triple_collector.append((group_name, tc))
84-
dc = []
140+
dc: _DoubleCollector[T] = []
85141
double_collector.append((group_name, dc))
86142
self._process(subchoices, tc, dc)
87143
else:
88-
store((choice[0], choice[0], choice[1]))
144+
store((choice[0], cast(str, choice[0]), cast('StrOrPromise', choice[1])))
89145
else:
90146
raise ValueError(
91147
"Choices can't take a list of length %s, only 2 or 3"
@@ -94,54 +150,74 @@ def _process(self, choices, triple_collector=None, double_collector=None):
94150
else:
95151
store((choice, choice, choice))
96152

97-
def __len__(self):
153+
def __len__(self) -> int:
98154
return len(self._doubles)
99155

100-
def __iter__(self):
156+
def __iter__(self) -> Iterator[_DoubleRead[T]]:
101157
return iter(self._doubles)
102158

103-
def __reversed__(self):
159+
def __reversed__(self) -> Iterator[_DoubleRead[T]]:
104160
return reversed(self._doubles)
105161

106-
def __getattr__(self, attname):
162+
def __getattr__(self, attname: str) -> T:
107163
try:
108164
return self._identifier_map[attname]
109165
except KeyError:
110166
raise AttributeError(attname)
111167

112-
def __getitem__(self, key):
168+
def __getitem__(self, key: T) -> StrOrPromise | Sequence[_Triple[T]]:
113169
return self._display_map[key]
114170

115-
def __add__(self, other):
171+
@overload
172+
def __add__(self: Choices[str], other: Choices[str] | Iterable[_ChoiceStr]) -> Choices[str]:
173+
...
174+
175+
@overload
176+
def __add__(self, other: Choices[T] | Iterable[_Choice[T]]) -> Choices[T]:
177+
...
178+
179+
def __add__(self, other: Choices[Any] | Iterable[_ChoiceStr | _Choice[Any]]) -> Choices[Any]:
180+
other_args: list[Any]
116181
if isinstance(other, self.__class__):
117-
other = other._triples
182+
other_args = other._triples
118183
else:
119-
other = list(other)
120-
return Choices(*(self._triples + other))
184+
other_args = list(other)
185+
return Choices(*(self._triples + other_args))
186+
187+
@overload
188+
def __radd__(self: Choices[str], other: Iterable[_ChoiceStr]) -> Choices[str]:
189+
...
190+
191+
@overload
192+
def __radd__(self, other: Iterable[_Choice[T]]) -> Choices[T]:
193+
...
121194

122-
def __radd__(self, other):
195+
def __radd__(self, other: Iterable[_ChoiceStr] | Iterable[_Choice[T]]) -> Choices[Any]:
123196
# radd is never called for matching types, so we don't check here
124-
other = list(other)
125-
return Choices(*(other + self._triples))
197+
other_args = list(other)
198+
# The exact type of 'other' depends on our type argument 'T', which
199+
# is expressed in the overloading, but lost within this method body.
200+
return Choices(*(other_args + self._triples)) # type: ignore[arg-type]
126201

127-
def __eq__(self, other):
202+
def __eq__(self, other: object) -> bool:
128203
if isinstance(other, self.__class__):
129204
return self._triples == other._triples
130205
return False
131206

132-
def __repr__(self):
207+
def __repr__(self) -> str:
133208
return '{}({})'.format(
134209
self.__class__.__name__,
135210
', '.join("%s" % repr(i) for i in self._triples)
136211
)
137212

138-
def __contains__(self, item):
213+
def __contains__(self, item: T) -> bool:
139214
return item in self._db_values
140215

141-
def __deepcopy__(self, memo):
142-
return self.__class__(*copy.deepcopy(self._triples, memo))
216+
def __deepcopy__(self, memo: dict[int, Any] | None) -> Choices[T]:
217+
args: list[Any] = copy.deepcopy(self._triples, memo)
218+
return self.__class__(*args)
143219

144-
def subset(self, *new_identifiers):
220+
def subset(self, *new_identifiers: str) -> Choices[T]:
145221
identifiers = set(self._identifier_map.keys())
146222

147223
if not identifiers.issuperset(new_identifiers):
@@ -150,7 +226,8 @@ def subset(self, *new_identifiers):
150226
identifiers.symmetric_difference(new_identifiers),
151227
)
152228

153-
return self.__class__(*[
229+
args: list[Any] = [
154230
choice for choice in self._triples
155231
if choice[1] in new_identifiers
156-
])
232+
]
233+
return self.__class__(*args)

0 commit comments

Comments
 (0)