Skip to content

Commit 5d022e9

Browse files
asvetlovDreamsorcererbdraco
authored
Properly support set operations for case insensitive multidict views (#1038)
For #965 --------- Co-authored-by: Sam Bull <[email protected]> Co-authored-by: Sam Bull <[email protected]> Co-authored-by: J. Nick Koston <[email protected]>
1 parent 675c4ae commit 5d022e9

File tree

9 files changed

+2030
-313
lines changed

9 files changed

+2030
-313
lines changed

CHANGES/965.bugfix

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Set operations for ``KeysView`` and ``ItemsView`` of case-insensitive multidicts and their proxies are processed in case-insensitive manner.

multidict/_multidict.c

+1-5
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ multidict_mp_as_subscript(MultiDictObject *self, PyObject *key, PyObject *val)
380380
static inline int
381381
multidict_sq_contains(MultiDictObject *self, PyObject *key)
382382
{
383-
return pair_list_contains(&self->pairs, key);
383+
return pair_list_contains(&self->pairs, key, NULL);
384384
}
385385

386386
static inline PyObject *
@@ -1343,10 +1343,6 @@ module_free(void *m)
13431343
{
13441344
Py_CLEAR(multidict_str_lower);
13451345
Py_CLEAR(multidict_str_canonical);
1346-
Py_CLEAR(viewbaseset_and_func);
1347-
Py_CLEAR(viewbaseset_or_func);
1348-
Py_CLEAR(viewbaseset_sub_func);
1349-
Py_CLEAR(viewbaseset_xor_func);
13501346
}
13511347

13521348
static PyMethodDef multidict_module_methods[] = {

multidict/_multidict_base.py

-51
This file was deleted.

multidict/_multidict_py.py

+263-6
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414
from typing import (
1515
TYPE_CHECKING,
16+
Any,
1617
Generic,
1718
NoReturn,
1819
Optional,
@@ -127,6 +128,149 @@ def __repr__(self) -> str:
127128
body = ", ".join(lst)
128129
return f"<{self.__class__.__name__}({body})>"
129130

131+
def _parse_item(
132+
self, arg: Union[tuple[str, _V], _T]
133+
) -> Optional[tuple[str, str, _V]]:
134+
if not isinstance(arg, tuple):
135+
return None
136+
if len(arg) != 2:
137+
return None
138+
try:
139+
return (self._identfunc(arg[0]), arg[0], arg[1])
140+
except TypeError:
141+
return None
142+
143+
def _tmp_set(self, it: Iterable[_T]) -> set[tuple[str, _V]]:
144+
tmp = set()
145+
for arg in it:
146+
item = self._parse_item(arg)
147+
if item is None:
148+
continue
149+
else:
150+
tmp.add((item[0], item[2]))
151+
return tmp
152+
153+
def __and__(self, other: Iterable[Any]) -> set[tuple[str, _V]]:
154+
ret = set()
155+
try:
156+
it = iter(other)
157+
except TypeError:
158+
return NotImplemented
159+
for arg in it:
160+
item = self._parse_item(arg)
161+
if item is None:
162+
continue
163+
identity, key, value = item
164+
for i, k, v in self._impl._items:
165+
if i == identity and v == value:
166+
ret.add((k, v))
167+
return ret
168+
169+
def __rand__(self, other: Iterable[_T]) -> set[_T]:
170+
ret = set()
171+
try:
172+
it = iter(other)
173+
except TypeError:
174+
return NotImplemented
175+
for arg in it:
176+
item = self._parse_item(arg)
177+
if item is None:
178+
continue
179+
identity, key, value = item
180+
for i, k, v in self._impl._items:
181+
if i == identity and v == value:
182+
ret.add(arg)
183+
break
184+
return ret
185+
186+
def __or__(self, other: Iterable[_T]) -> set[Union[tuple[str, _V], _T]]:
187+
ret: set[Union[tuple[str, _V], _T]] = set(self)
188+
try:
189+
it = iter(other)
190+
except TypeError:
191+
return NotImplemented
192+
for arg in it:
193+
item: Optional[tuple[str, str, _V]] = self._parse_item(arg)
194+
if item is None:
195+
ret.add(arg)
196+
continue
197+
identity, key, value = item
198+
for i, k, v in self._impl._items:
199+
if i == identity and v == value:
200+
break
201+
else:
202+
ret.add(arg)
203+
return ret
204+
205+
def __ror__(self, other: Iterable[_T]) -> set[Union[tuple[str, _V], _T]]:
206+
try:
207+
ret: set[Union[tuple[str, _V], _T]] = set(other)
208+
except TypeError:
209+
return NotImplemented
210+
tmp = self._tmp_set(ret)
211+
212+
for i, k, v in self._impl._items:
213+
if (i, v) not in tmp:
214+
ret.add((k, v))
215+
return ret
216+
217+
def __sub__(self, other: Iterable[_T]) -> set[Union[tuple[str, _V], _T]]:
218+
ret: set[Union[tuple[str, _V], _T]] = set()
219+
try:
220+
it = iter(other)
221+
except TypeError:
222+
return NotImplemented
223+
tmp = self._tmp_set(it)
224+
225+
for i, k, v in self._impl._items:
226+
if (i, v) not in tmp:
227+
ret.add((k, v))
228+
229+
return ret
230+
231+
def __rsub__(self, other: Iterable[_T]) -> set[_T]:
232+
ret: set[_T] = set()
233+
try:
234+
it = iter(other)
235+
except TypeError:
236+
return NotImplemented
237+
for arg in it:
238+
item = self._parse_item(arg)
239+
if item is None:
240+
ret.add(arg)
241+
continue
242+
243+
identity, key, value = item
244+
for i, k, v in self._impl._items:
245+
if i == identity and v == value:
246+
break
247+
else:
248+
ret.add(arg)
249+
return ret
250+
251+
def __xor__(self, other: Iterable[_T]) -> set[Union[tuple[str, _V], _T]]:
252+
try:
253+
rgt = set(other)
254+
except TypeError:
255+
return NotImplemented
256+
ret: set[Union[tuple[str, _V], _T]] = self - rgt
257+
ret |= (rgt - self)
258+
return ret
259+
260+
__rxor__ = __xor__
261+
262+
def isdisjoint(self, other: Iterable[tuple[str, _V]]) -> bool:
263+
for arg in other:
264+
item = self._parse_item(arg)
265+
if item is None:
266+
continue
267+
268+
identity, key, value = item
269+
for i, k, v in self._impl._items:
270+
if i == identity and v == value:
271+
return False
272+
return True
273+
130274

131275
class _ValuesView(_ViewBase[_V], ValuesView[_V]):
132276
def __contains__(self, value: object) -> bool:
@@ -178,6 +322,124 @@ def __repr__(self) -> str:
178322
body = ", ".join(lst)
179323
return f"<{self.__class__.__name__}({body})>"
180324

325+
def __and__(self, other: Iterable[object]) -> set[str]:
326+
ret = set()
327+
try:
328+
it = iter(other)
329+
except TypeError:
330+
return NotImplemented
331+
for key in it:
332+
if not isinstance(key, str):
333+
continue
334+
identity = self._identfunc(key)
335+
for i, k, v in self._impl._items:
336+
if i == identity:
337+
ret.add(k)
338+
return ret
339+
340+
def __rand__(self, other: Iterable[_T]) -> set[_T]:
341+
ret = set()
342+
try:
343+
it = iter(other)
344+
except TypeError:
345+
return NotImplemented
346+
for key in it:
347+
if not isinstance(key, str):
348+
continue
349+
identity = self._identfunc(key)
350+
for i, k, v in self._impl._items:
351+
if i == identity:
352+
ret.add(key)
353+
return cast(set[_T], ret)
354+
355+
def __or__(self, other: Iterable[_T]) -> set[Union[str, _T]]:
356+
ret: set[Union[str, _T]] = set(self)
357+
try:
358+
it = iter(other)
359+
except TypeError:
360+
return NotImplemented
361+
for key in it:
362+
if not isinstance(key, str):
363+
ret.add(key)
364+
continue
365+
identity = self._identfunc(key)
366+
for i, k, v in self._impl._items:
367+
if i == identity:
368+
break
369+
else:
370+
ret.add(key)
371+
return ret
372+
373+
def __ror__(self, other: Iterable[_T]) -> set[Union[str, _T]]:
374+
try:
375+
ret: set[Union[str, _T]] = set(other)
376+
except TypeError:
377+
return NotImplemented
378+
379+
tmp = set()
380+
for key in ret:
381+
if not isinstance(key, str):
382+
continue
383+
identity = self._identfunc(key)
384+
tmp.add(identity)
385+
386+
for i, k, v in self._impl._items:
387+
if i not in tmp:
388+
ret.add(k)
389+
return ret
390+
391+
def __sub__(self, other: Iterable[object]) -> set[str]:
392+
ret = set(self)
393+
try:
394+
it = iter(other)
395+
except TypeError:
396+
return NotImplemented
397+
for key in it:
398+
if not isinstance(key, str):
399+
continue
400+
identity = self._identfunc(key)
401+
for i, k, v in self._impl._items:
402+
if i == identity:
403+
ret.discard(k)
404+
break
405+
return ret
406+
407+
def __rsub__(self, other: Iterable[_T]) -> set[_T]:
408+
try:
409+
ret: set[_T] = set(other)
410+
except TypeError:
411+
return NotImplemented
412+
for key in other:
413+
if not isinstance(key, str):
414+
continue
415+
identity = self._identfunc(key)
416+
for i, k, v in self._impl._items:
417+
if i == identity:
418+
ret.discard(key) # type: ignore[arg-type]
419+
break
420+
return ret
421+
422+
def __xor__(self, other: Iterable[_T]) -> set[Union[str, _T]]:
423+
try:
424+
rgt = set(other)
425+
except TypeError:
426+
return NotImplemented
427+
ret: set[Union[str, _T]] = self - rgt # type: ignore[assignment]
428+
ret |= (rgt - self)
429+
return ret
430+
431+
__rxor__ = __xor__
432+
433+
def isdisjoint(self, other: Iterable[object]) -> bool:
434+
for key in other:
435+
if not isinstance(key, str):
436+
continue
437+
identity = self._identfunc(key)
438+
for i, k, v in self._impl._items:
439+
if i == identity:
440+
return False
441+
return True
442+
181443

182444
class _CSMixin:
183445
def _key(self, key: str) -> str:
@@ -388,12 +650,7 @@ def _extend(
388650

389651
method(items)
390652
else:
391-
method(
392-
[
393-
(self._title(key), key, value)
394-
for key, value in kwargs.items()
395-
]
396-
)
653+
method([(self._title(key), key, value) for key, value in kwargs.items()])
397654

398655
def _extend_items(self, items: Iterable[tuple[str, str, _V]]) -> None:
399656
for identity, key, value in items:

0 commit comments

Comments
 (0)