diff --git a/CHANGES/1097.bugfix b/CHANGES/1097.bugfix new file mode 100644 index 000000000..aded29cc7 --- /dev/null +++ b/CHANGES/1097.bugfix @@ -0,0 +1,6 @@ +Rewrote :class:`multidict.CIMultiDict` and its proxy to always return +:class:`multidict.istr` keys. ``istr`` is derived from :class:`str`, +thus the change is backward compatible. + +The performance boost is about 15% for some operations in the C extension; +the pure Python implementation gets a visible (15% - 230%) speedup as well. diff --git a/multidict/_multidict.c b/multidict/_multidict.c index 674c1ceea..dda5a821f 100644 --- a/multidict/_multidict.c +++ b/multidict/_multidict.c @@ -126,7 +126,7 @@ _multidict_extend(MultiDictObject *self, PyObject *arg, } -static inline int +static inline Py_ssize_t _multidict_extend_parse_args(PyObject *args, PyObject *kwds, const char *name, PyObject **parg) { @@ -344,85 +344,13 @@ multidict_reduce(MultiDictObject *self) return result; } -static inline PyObject * -_do_multidict_repr(MultiDictObject *md, PyObject *name, - bool show_keys, bool show_values) -{ - PyObject *key = NULL, - *value = NULL; - bool comma = false; - - PyUnicodeWriter *writer = PyUnicodeWriter_Create(1024); - if (writer == NULL) - return NULL; - - if (PyUnicodeWriter_WriteChar(writer, '<') <0) - goto fail; - if (PyUnicodeWriter_WriteStr(writer, name) <0) - goto fail; - if (PyUnicodeWriter_WriteChar(writer, '(') <0) - goto fail; - - pair_list_pos_t pos; - pair_list_init_pos(&md->pairs, &pos); - - for (;;) { - int res = pair_list_next(&md->pairs, &pos, &key, &value); - if (res < 0) { - goto fail; - } - if (res == 0) { - break; - } - - if (comma) { - if (PyUnicodeWriter_WriteChar(writer, ',') <0) - goto fail; - if (PyUnicodeWriter_WriteChar(writer, ' ') <0) - goto fail; - } - if (show_keys) { - if (PyUnicodeWriter_WriteChar(writer, '\'') <0) - goto fail; - if (PyUnicodeWriter_WriteStr(writer, key) <0) - goto fail; - if (PyUnicodeWriter_WriteChar(writer, '\'') <0) - goto fail; - } - if (show_keys && 
show_values) { - if (PyUnicodeWriter_WriteChar(writer, ':') <0) - goto fail; - if (PyUnicodeWriter_WriteChar(writer, ' ') <0) - goto fail; - } - if (show_values) { - if (PyUnicodeWriter_WriteRepr(writer, value) <0) - goto fail; - } - - Py_CLEAR(key); - Py_CLEAR(value); - comma = true; - } - - if (PyUnicodeWriter_WriteChar(writer, ')') <0) - goto fail; - if (PyUnicodeWriter_WriteChar(writer, '>') <0) - goto fail; - return PyUnicodeWriter_Finish(writer); -fail: - Py_CLEAR(key); - Py_CLEAR(value); - PyUnicodeWriter_Discard(writer); -} - static inline PyObject * multidict_repr(MultiDictObject *self) { PyObject *name = PyObject_GetAttrString((PyObject*)Py_TYPE(self), "__name__"); if (name == NULL) return NULL; - PyObject *ret = _do_multidict_repr(self, name, true, true); + PyObject *ret = pair_list_repr(&self->pairs, name, true, true); Py_CLEAR(name); return ret; } @@ -604,9 +532,11 @@ static inline PyObject * multidict_extend(MultiDictObject *self, PyObject *args, PyObject *kwds) { PyObject *arg = NULL; - if (_multidict_extend_parse_args(args, kwds, "extend", &arg) < 0) { + Py_ssize_t size = _multidict_extend_parse_args(args, kwds, "extend", &arg); + if (size < 0) { return NULL; } + pair_list_grow(&self->pairs, size); if (_multidict_extend(self, arg, kwds, "extend", 1) < 0) { return NULL; } @@ -1200,7 +1130,7 @@ multidict_proxy_repr(MultiDictProxyObject *self) PyObject *name = PyObject_GetAttrString((PyObject*)Py_TYPE(self), "__name__"); if (name == NULL) return NULL; - PyObject *ret = _do_multidict_repr(self->md, name, true, true); + PyObject *ret = pair_list_repr(&self->md->pairs, name, true, true); Py_CLEAR(name); return ret; } @@ -1412,6 +1342,7 @@ static inline void module_free(void *m) { Py_CLEAR(multidict_str_lower); + Py_CLEAR(multidict_str_canonical); Py_CLEAR(viewbaseset_and_func); Py_CLEAR(viewbaseset_or_func); Py_CLEAR(viewbaseset_sub_func); @@ -1438,6 +1369,10 @@ PyInit__multidict(void) if (multidict_str_lower == NULL) { goto fail; } + 
multidict_str_canonical = PyUnicode_InternFromString("_canonical"); + if (multidict_str_canonical == NULL) { + goto fail; + } PyObject *module = NULL; @@ -1531,6 +1466,7 @@ PyInit__multidict(void) fail: Py_XDECREF(multidict_str_lower); + Py_XDECREF(multidict_str_canonical); return NULL; } diff --git a/multidict/_multidict_py.py b/multidict/_multidict_py.py index f10f5be83..6ea02579e 100644 --- a/multidict/_multidict_py.py +++ b/multidict/_multidict_py.py @@ -1,5 +1,6 @@ import enum import sys +from abc import abstractmethod from array import array from collections.abc import ( Callable, @@ -14,6 +15,7 @@ TYPE_CHECKING, Generic, NoReturn, + Optional, TypeVar, Union, cast, @@ -32,6 +34,7 @@ class istr(str): """Case insensitive str.""" __is_istr__ = True + __istr_title__: Optional[str] = None _V = TypeVar("_V") @@ -80,8 +83,15 @@ def __length_hint__(self) -> int: class _ViewBase(Generic[_V]): - def __init__(self, impl: _Impl[_V]): + def __init__( + self, + impl: _Impl[_V], + identfunc: Callable[[str], str], + keyfunc: Callable[[str], str], + ): self._impl = impl + self._identfunc = identfunc + self._keyfunc = keyfunc def __len__(self) -> int: return len(self._impl._items) @@ -91,8 +101,13 @@ class _ItemsView(_ViewBase[_V], ItemsView[str, _V]): def __contains__(self, item: object) -> bool: if not isinstance(item, (tuple, list)) or len(item) != 2: return False + key, value = item + try: + ident = self._identfunc(key) + except TypeError: + return False for i, k, v in self._impl._items: - if item[0] == k and item[1] == v: + if ident == i and value == v: return True return False @@ -103,20 +118,20 @@ def _iter(self, version: int) -> Iterator[tuple[str, _V]]: for i, k, v in self._impl._items: if version != self._impl._version: raise RuntimeError("Dictionary changed during iteration") - yield k, v + yield self._keyfunc(k), v def __repr__(self) -> str: lst = [] - for item in self._impl._items: - lst.append("{!r}: {!r}".format(item[1], item[2])) + for i, k, v in 
self._impl._items: + lst.append(f"'{k}': {v!r}") body = ", ".join(lst) - return "<{}({})>".format(self.__class__.__name__, body) + return f"<{self.__class__.__name__}({body})>" class _ValuesView(_ViewBase[_V], ValuesView[_V]): def __contains__(self, value: object) -> bool: - for item in self._impl._items: - if item[2] == value: + for i, k, v in self._impl._items: + if v == value: return True return False @@ -124,23 +139,26 @@ def __iter__(self) -> _Iter[_V]: return _Iter(len(self), self._iter(self._impl._version)) def _iter(self, version: int) -> Iterator[_V]: - for item in self._impl._items: + for i, k, v in self._impl._items: if version != self._impl._version: raise RuntimeError("Dictionary changed during iteration") - yield item[2] + yield v def __repr__(self) -> str: lst = [] - for item in self._impl._items: - lst.append("{!r}".format(item[2])) + for i, k, v in self._impl._items: + lst.append(repr(v)) body = ", ".join(lst) - return "<{}({})>".format(self.__class__.__name__, body) + return f"<{self.__class__.__name__}({body})>" class _KeysView(_ViewBase[_V], KeysView[str]): def __contains__(self, key: object) -> bool: - for item in self._impl._items: - if item[1] == key: + if not isinstance(key, str): + return False + identity = self._identfunc(key) + for i, k, v in self._impl._items: + if i == identity: return True return False @@ -148,24 +166,58 @@ def __iter__(self) -> _Iter[str]: return _Iter(len(self), self._iter(self._impl._version)) def _iter(self, version: int) -> Iterator[str]: - for item in self._impl._items: + for i, k, v in self._impl._items: if version != self._impl._version: raise RuntimeError("Dictionary changed during iteration") - yield item[1] + yield self._keyfunc(k) def __repr__(self) -> str: lst = [] - for item in self._impl._items: - lst.append("{!r}".format(item[1])) + for i, k, v in self._impl._items: + lst.append(f"'{k}'") body = ", ".join(lst) - return "<{}({})>".format(self.__class__.__name__, body) + return 
f"<{self.__class__.__name__}({body})>" + + +class _CSMixin: + def _key(self, key: str) -> str: + return key + + def _title(self, key: str) -> str: + if isinstance(key, str): + return key + else: + raise TypeError("MultiDict keys should be either str or subclasses of str") + + +class _CIMixin: + def _key(self, key: str) -> str: + if type(key) is istr: + return key + else: + return istr(key) + + def _title(self, key: str) -> str: + if isinstance(key, istr): + ret = key.__istr_title__ + if ret is None: + ret = key.title() + key.__istr_title__ = ret + return ret + if isinstance(key, str): + return key.title() + else: + raise TypeError("MultiDict keys should be either str or subclasses of str") class _Base(MultiMapping[_V]): _impl: _Impl[_V] - def _title(self, key: str) -> str: - return key + @abstractmethod + def _key(self, key: str) -> str: ... + + @abstractmethod + def _title(self, key: str) -> str: ... @overload def getall(self, key: str) -> list[_V]: ... @@ -226,15 +278,15 @@ def __len__(self) -> int: def keys(self) -> KeysView[str]: """Return a new view of the dictionary's keys.""" - return _KeysView(self._impl) + return _KeysView(self._impl, self._title, self._key) def items(self) -> ItemsView[str, _V]: """Return a new view of the dictionary's items *(key, value) pairs).""" - return _ItemsView(self._impl) + return _ItemsView(self._impl, self._title, self._key) def values(self) -> _ValuesView[_V]: """Return a new view of the dictionary's values.""" - return _ValuesView(self._impl) + return _ValuesView(self._impl, self._title, self._key) def __eq__(self, other: object) -> bool: if not isinstance(other, Mapping): @@ -266,11 +318,11 @@ def __contains__(self, key: object) -> bool: return False def __repr__(self) -> str: - body = ", ".join("'{}': {!r}".format(k, v) for k, v in self.items()) - return "<{}({})>".format(self.__class__.__name__, body) + body = ", ".join(f"'{k}': {v!r}" for i, k, v in self._impl._items) + return f"<{self.__class__.__name__}({body})>" -class 
MultiDict(_Base[_V], MutableMultiMapping[_V]): +class MultiDict(_CSMixin, _Base[_V], MutableMultiMapping[_V]): """Dictionary with the support for duplicate keys.""" def __init__(self, arg: MDArg[_V] = None, /, **kwargs: _V): @@ -286,18 +338,9 @@ def __sizeof__(self) -> int: def __reduce__(self) -> tuple[type[Self], tuple[list[tuple[str, _V]]]]: return (self.__class__, (list(self.items()),)) - def _title(self, key: str) -> str: - return key - - def _key(self, key: str) -> str: - if isinstance(key, str): - return key - else: - raise TypeError("MultiDict keys should be either str or subclasses of str") - def add(self, key: str, value: _V) -> None: identity = self._title(key) - self._impl._items.append((identity, self._key(key), value)) + self._impl._items.append((identity, key, value)) self._impl.incr_version() def copy(self) -> Self: @@ -322,8 +365,11 @@ def _extend( method: Callable[[list[tuple[str, str, _V]]], None], ) -> None: if arg: - if isinstance(arg, (MultiDict, MultiDictProxy)) and not kwargs: + if isinstance(arg, (MultiDict, MultiDictProxy)): items = arg._impl._items + if kwargs: + items = list(items) + items += [(self._title(k), k, v) for k, v in kwargs.items()] else: if hasattr(arg, "keys"): arg = cast(SupportsKeys[_V], arg) @@ -338,20 +384,21 @@ def _extend( f"multidict update sequence element #{pos}" f"has length {len(item)}; 2 is required" ) - items.append((self._title(item[0]), self._key(item[0]), item[1])) + items.append((self._title(item[0]), item[0], item[1])) method(items) else: method( [ - (self._title(key), self._key(key), value) + (self._title(key), key, value) for key, value in kwargs.items() ] ) def _extend_items(self, items: Iterable[tuple[str, str, _V]]) -> None: for identity, key, value in items: - self.add(key, value) + self._impl._items.append((identity, key, value)) + self._impl.incr_version() def clear(self) -> None: """Remove all items from MultiDict.""" @@ -456,9 +503,9 @@ def popall( def popitem(self) -> tuple[str, _V]: """Remove and 
return an arbitrary (key, value) pair.""" if self._impl._items: - i = self._impl._items.pop() + i, k, v = self._impl._items.pop() self._impl.incr_version() - return i[1], i[2] + return self._key(k), v else: raise KeyError("empty multidict") @@ -499,7 +546,6 @@ def _update_items(self, items: list[tuple[str, str, _V]]) -> None: self._impl.incr_version() def _replace(self, key: str, value: _V) -> None: - key = self._key(key) identity = self._title(key) items = self._impl._items @@ -527,48 +573,42 @@ def _replace(self, key: str, value: _V) -> None: i += 1 -class CIMultiDict(MultiDict[_V]): +class CIMultiDict(_CIMixin, MultiDict[_V]): """Dictionary with the support for duplicate case-insensitive keys.""" - def _title(self, key: str) -> str: - return key.title() - -class MultiDictProxy(_Base[_V]): +class MultiDictProxy(_CSMixin, _Base[_V]): """Read-only proxy for MultiDict instance.""" def __init__(self, arg: Union[MultiDict[_V], "MultiDictProxy[_V]"]): if not isinstance(arg, (MultiDict, MultiDictProxy)): raise TypeError( "ctor requires MultiDict or MultiDictProxy instance" - ", not {}".format(type(arg)) + f", not {type(arg)}" ) self._impl = arg._impl def __reduce__(self) -> NoReturn: - raise TypeError("can't pickle {} objects".format(self.__class__.__name__)) + raise TypeError(f"can't pickle {self.__class__.__name__} objects") def copy(self) -> MultiDict[_V]: """Return a copy of itself.""" return MultiDict(self.items()) -class CIMultiDictProxy(MultiDictProxy[_V]): +class CIMultiDictProxy(_CIMixin, MultiDictProxy[_V]): """Read-only proxy for CIMultiDict instance.""" def __init__(self, arg: Union[MultiDict[_V], MultiDictProxy[_V]]): if not isinstance(arg, (CIMultiDict, CIMultiDictProxy)): raise TypeError( "ctor requires CIMultiDict or CIMultiDictProxy instance" - ", not {}".format(type(arg)) + f", not {type(arg)}" ) self._impl = arg._impl - def _title(self, key: str) -> str: - return key.title() - def copy(self) -> CIMultiDict[_V]: """Return a copy of itself.""" return 
CIMultiDict(self.items()) diff --git a/multidict/_multilib/defs.h b/multidict/_multilib/defs.h index 51a6639c4..9e1cd7241 100644 --- a/multidict/_multilib/defs.h +++ b/multidict/_multilib/defs.h @@ -6,6 +6,7 @@ extern "C" { #endif static PyObject *multidict_str_lower = NULL; +static PyObject *multidict_str_canonical = NULL; /* We link this module statically for convenience. If compiled as a shared library instead, some compilers don't allow addresses of Python objects diff --git a/multidict/_multilib/dict.h b/multidict/_multilib/dict.h index bd82411e8..064101d47 100644 --- a/multidict/_multilib/dict.h +++ b/multidict/_multilib/dict.h @@ -18,11 +18,6 @@ typedef struct { } MultiDictProxyObject; -static inline PyObject * -_do_multidict_repr(MultiDictObject *md, PyObject *name, - bool show_keys, bool show_values); - - #ifdef __cplusplus } #endif diff --git a/multidict/_multilib/istr.h b/multidict/_multilib/istr.h index 5042e6e56..683280545 100644 --- a/multidict/_multilib/istr.h +++ b/multidict/_multilib/istr.h @@ -28,8 +28,16 @@ istr_new(PyTypeObject *type, PyObject *args, PyObject *kwds) static char *kwlist[] = {"object", "encoding", "errors", 0}; PyObject *encoding = NULL; PyObject *errors = NULL; - PyObject *s = NULL; + PyObject *canonical = NULL; PyObject * ret = NULL; + if (kwds != NULL) { + int cmp = PyDict_Pop(kwds, multidict_str_canonical, &canonical); + if (cmp < 0) { + return NULL; + } + /* cmp > 0: PyDict_Pop() already stored a new strong reference in + canonical, so an extra Py_INCREF() here would leak it. */ + } if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOO:str", kwlist, &x, &encoding, &errors)) { @@ -43,12 +51,22 @@ istr_new(PyTypeObject *type, PyObject *args, PyObject *kwds) if (!ret) { goto fail; } - s = PyObject_CallMethodNoArgs(ret, multidict_str_lower); - if (!s) { - goto fail; + + if (canonical == NULL) { + canonical = PyObject_CallMethodNoArgs(ret, multidict_str_lower); + if (!canonical) { + goto fail; + } } - ((istrobject*)ret)->canonical = s; - s = NULL; /* the reference is stollen by .canonical */ + if 
(!PyUnicode_CheckExact(canonical)) { + PyObject *tmp = PyUnicode_FromObject(canonical); + Py_CLEAR(canonical); + if (tmp == NULL) { + goto fail; + } + canonical = tmp; + } + ((istrobject*)ret)->canonical = canonical; return ret; fail: Py_XDECREF(ret); @@ -99,6 +117,40 @@ static PyTypeObject istr_type = { }; +static inline PyObject * +IStr_New(PyObject *str, PyObject *canonical) +{ + PyObject *args = NULL; + PyObject *kwds = NULL; + PyObject *res = NULL; + + args = PyTuple_Pack(1, str); + if (args == NULL) { + goto ret; + } + + if (canonical != NULL) { + kwds = PyDict_New(); + if (kwds == NULL) { + goto ret; + } + if (!PyUnicode_CheckExact(canonical)) { + PyErr_SetString(PyExc_TypeError, + "'canonical' argument should be exactly str"); + goto ret; + } + if (PyDict_SetItem(kwds, multidict_str_canonical, canonical) < 0) { + goto ret; + } + } + + res = istr_new(&istr_type, args, kwds); +ret: + Py_CLEAR(args); + Py_CLEAR(kwds); + return res; +} + static inline int istr_init(void) { diff --git a/multidict/_multilib/pair_list.h b/multidict/_multilib/pair_list.h index e26871ffd..db5a34893 100644 --- a/multidict/_multilib/pair_list.h +++ b/multidict/_multilib/pair_list.h @@ -12,6 +12,18 @@ extern "C" { #include #include +/* Implementation note. +The identity always has the exact PyUnicode_Type type, not a subclass. +This guarantees that identity hashing and comparison never call back +into Python code, so these operations have no weird side effects, +e.g. deleting the key from the multidict. + +Since all multidict operations except repr(md), repr(md_proxy), and +repr(view) never access the key itself but only the identity, +borrowed references are safe while iterating over the pair_list +in, e.g., md.get() or md.pop(). 
+*/ + typedef struct pair { PyObject *identity; // 8 PyObject *key; // 8 @@ -80,18 +92,17 @@ str_cmp(PyObject *s1, PyObject *s2) static inline PyObject * -key_to_str(PyObject *key) +_key_to_ident(PyObject *key) { PyTypeObject *type = Py_TYPE(key); if (type == &istr_type) { return Py_NewRef(((istrobject*)key)->canonical); } if (PyUnicode_CheckExact(key)) { - Py_INCREF(key); - return key; + return Py_NewRef(key); } if (PyUnicode_Check(key)) { - return PyObject_Str(key); + return PyUnicode_FromObject(key); } PyErr_SetString(PyExc_TypeError, "MultiDict keys should be either str " @@ -101,14 +112,53 @@ key_to_str(PyObject *key) static inline PyObject * -ci_key_to_str(PyObject *key) +_ci_key_to_ident(PyObject *key) { PyTypeObject *type = Py_TYPE(key); if (type == &istr_type) { return Py_NewRef(((istrobject*)key)->canonical); } if (PyUnicode_Check(key)) { - return PyObject_CallMethodNoArgs(key, multidict_str_lower); + PyObject *ret = PyObject_CallMethodNoArgs(key, multidict_str_lower); + if (ret != NULL && !PyUnicode_CheckExact(ret)) { /* ret is NULL if lower() raised */ + PyObject *tmp = PyUnicode_FromObject(ret); + Py_CLEAR(ret); + if (tmp == NULL) { + return NULL; + } + ret = tmp; + } + return ret; } PyErr_SetString(PyExc_TypeError, "CIMultiDict keys should be either str " "or subclasses of str"); + return NULL; +} + + +static inline PyObject * +_arg_to_key(PyObject *key, PyObject *ident) +{ + if (PyUnicode_Check(key)) { + return Py_NewRef(key); + } + PyErr_SetString(PyExc_TypeError, + "MultiDict keys should be either str " + "or subclasses of str"); + return NULL; +} + + +static inline PyObject * +_ci_arg_to_key(PyObject *key, PyObject *ident) +{ + PyTypeObject *type = Py_TYPE(key); + if (type == &istr_type) { + return Py_NewRef(key); + } + if (PyUnicode_Check(key)) { + return IStr_New(key, ident); } PyErr_SetString(PyExc_TypeError, "CIMultiDict keys should be either str " @@ -118,26 +168,27 @@ ci_key_to_str(PyObject *key) static inline int -pair_list_grow(pair_list_t *list) +pair_list_grow(pair_list_t *list, 
Py_ssize_t amount) { // Grow by one element if needed - Py_ssize_t new_capacity; + Py_ssize_t capacity = ((Py_ssize_t)((list->size + amount) + / CAPACITY_STEP) + 1) * CAPACITY_STEP; + pair_t *new_pairs; - if (list->size < list->capacity) { + if (list->size + amount -1 < list->capacity) { return 0; } if (list->pairs == list->buffer) { - new_pairs = PyMem_New(pair_t, MIN_CAPACITY); + new_pairs = PyMem_New(pair_t, (size_t)capacity); memcpy(new_pairs, list->buffer, (size_t)list->capacity * sizeof(pair_t)); list->pairs = new_pairs; - list->capacity = MIN_CAPACITY; + list->capacity = capacity; return 0; } else { - new_capacity = list->capacity + CAPACITY_STEP; - new_pairs = PyMem_Resize(list->pairs, pair_t, (size_t)new_capacity); + new_pairs = PyMem_Resize(list->pairs, pair_t, (size_t)capacity); if (NULL == new_pairs) { // Resizing error @@ -145,7 +196,7 @@ pair_list_grow(pair_list_t *list) } list->pairs = new_pairs; - list->capacity = new_capacity; + list->capacity = capacity; return 0; } } @@ -223,8 +274,16 @@ static inline PyObject * pair_list_calc_identity(pair_list_t *list, PyObject *key) { if (list->calc_ci_indentity) - return ci_key_to_str(key); - return key_to_str(key); + return _ci_key_to_ident(key); + return _key_to_ident(key); +} + +static inline PyObject * +pair_list_calc_key(pair_list_t *list, PyObject *key, PyObject *ident) +{ + if (list->calc_ci_indentity) + return _ci_arg_to_key(key, ident); + return _arg_to_key(key, ident); } static inline void @@ -272,7 +331,7 @@ _pair_list_add_with_hash_steal_refs(pair_list_t *list, PyObject *value, Py_hash_t hash) { - if (pair_list_grow(list) < 0) { + if (pair_list_grow(list, 1) < 0) { return -1; } @@ -461,6 +520,15 @@ pair_list_next(pair_list_t *list, pair_list_pos_t *pos, pair_t *pair = list->pairs + pos->pos; if (pkey) { + PyObject *key = pair_list_calc_key(list, pair->key, pair->identity); + if (key == NULL) { + return -1; + } + if (key != pair->key) { + Py_SETREF(pair->key, key); + } else { + Py_CLEAR(key); + } 
*pkey = Py_NewRef(pair->key); } if (pvalue) { @@ -774,7 +842,12 @@ pair_list_pop_item(pair_list_t *list) Py_ssize_t pos = list->size - 1; pair_t *pair = list->pairs + pos; - PyObject *ret = PyTuple_Pack(2, pair->key, pair->value); + PyObject *key = pair_list_calc_key(list, pair->key, pair->identity); + if (key == NULL) { + return NULL; + } + PyObject *ret = PyTuple_Pack(2, key, pair->value); + Py_CLEAR(key); if (ret == NULL) { return NULL; } @@ -1313,6 +1386,81 @@ pair_list_eq_to_mapping(pair_list_t *list, PyObject *other) } +static inline PyObject * +pair_list_repr(pair_list_t *list, PyObject *name, + bool show_keys, bool show_values) +{ + PyObject *key = NULL; + PyObject *value = NULL; + + bool comma = false; + Py_ssize_t pos; + uint64_t version = list->version; + + PyUnicodeWriter *writer = PyUnicodeWriter_Create(1024); + if (writer == NULL) + return NULL; + + if (PyUnicodeWriter_WriteChar(writer, '<') <0) + goto fail; + if (PyUnicodeWriter_WriteStr(writer, name) <0) + goto fail; + if (PyUnicodeWriter_WriteChar(writer, '(') <0) + goto fail; + + for (pos = 0; pos < list->size; ++pos) { + if (version != list->version) { + PyErr_SetString(PyExc_RuntimeError, "MultiDict changed during iteration"); + goto fail; /* discard the writer instead of leaking it */ + } + pair_t *pair = list->pairs + pos; + key = Py_NewRef(pair->key); + value = Py_NewRef(pair->value); + + if (comma) { + if (PyUnicodeWriter_WriteChar(writer, ',') <0) + goto fail; + if (PyUnicodeWriter_WriteChar(writer, ' ') <0) + goto fail; + } + if (show_keys) { + if (PyUnicodeWriter_WriteChar(writer, '\'') <0) + goto fail; + /* Don't need to convert key to istr, the text is the same*/ + if (PyUnicodeWriter_WriteStr(writer, key) <0) + goto fail; + if (PyUnicodeWriter_WriteChar(writer, '\'') <0) + goto fail; + } + if (show_keys && show_values) { + if (PyUnicodeWriter_WriteChar(writer, ':') <0) + goto fail; + if (PyUnicodeWriter_WriteChar(writer, ' ') <0) + goto fail; + } + if (show_values) { + if (PyUnicodeWriter_WriteRepr(writer, value) <0) + goto fail; 
+ } + + comma = true; + Py_CLEAR(key); + Py_CLEAR(value); + } + + if (PyUnicodeWriter_WriteChar(writer, ')') <0) + goto fail; + if (PyUnicodeWriter_WriteChar(writer, '>') <0) + goto fail; + return PyUnicodeWriter_Finish(writer); +fail: + Py_CLEAR(key); + Py_CLEAR(value); + PyUnicodeWriter_Discard(writer); + return NULL; +} + + /***********************************************************************/ static inline int diff --git a/multidict/_multilib/views.h b/multidict/_multilib/views.h index 7630b31c2..ecb83d5d5 100644 --- a/multidict/_multilib/views.h +++ b/multidict/_multilib/views.h @@ -290,7 +290,7 @@ multidict_itemsview_repr(_Multidict_ViewObject *self) PyObject *name = PyObject_GetAttrString((PyObject*)Py_TYPE(self), "__name__"); if (name == NULL) return NULL; - PyObject *ret = _do_multidict_repr(self->md, name, true, true); + PyObject *ret = pair_list_repr(&self->md->pairs, name, true, true); Py_CLEAR(name); return ret; } @@ -405,7 +405,7 @@ multidict_keysview_repr(_Multidict_ViewObject *self) PyObject *name = PyObject_GetAttrString((PyObject*)Py_TYPE(self), "__name__"); if (name == NULL) return NULL; - PyObject *ret = _do_multidict_repr(self->md, name, true, false); + PyObject *ret = pair_list_repr(&self->md->pairs, name, true, false); Py_CLEAR(name); return ret; } @@ -469,7 +469,7 @@ multidict_valuesview_repr(_Multidict_ViewObject *self) PyObject *name = PyObject_GetAttrString((PyObject*)Py_TYPE(self), "__name__"); if (name == NULL) return NULL; - PyObject *ret = _do_multidict_repr(self->md, name, false, true); + PyObject *ret = pair_list_repr(&self->md->pairs, name, false, true); Py_CLEAR(name); return ret; } diff --git a/tests/test_multidict.py b/tests/test_multidict.py index 6af79ea2d..1716f74b8 100644 --- a/tests/test_multidict.py +++ b/tests/test_multidict.py @@ -835,3 +835,18 @@ def test_keys__repr__(self, cls: type[CIMultiDict[str]]) -> None: def test_values__repr__(self, cls: type[CIMultiDict[str]]) -> None: d = cls([("KEY", "value1")], key="value2") assert 
repr(d.values()) == "<_ValuesView('value1', 'value2')>" + + def test_items_iter_of_iter(self, cls: type[CIMultiDict[str]]) -> None: + d = cls([("KEY", "value1")], key="value2") + it = iter(d.items()) + assert iter(it) is it + + def test_keys_iter_of_iter(self, cls: type[CIMultiDict[str]]) -> None: + d = cls([("KEY", "value1")], key="value2") + it = iter(d.keys()) + assert iter(it) is it + + def test_values_iter_of_iter(self, cls: type[CIMultiDict[str]]) -> None: + d = cls([("KEY", "value1")], key="value2") + it = iter(d.values()) + assert iter(it) is it diff --git a/tests/test_mutable_multidict.py b/tests/test_mutable_multidict.py index 4c6cd0037..085999fb2 100644 --- a/tests/test_mutable_multidict.py +++ b/tests/test_mutable_multidict.py @@ -690,3 +690,28 @@ def test_issue_620_values( d["c"] = "000" # This causes an error on pypy. list(before_mutation_values) + + def test_keys_type( + self, + case_insensitive_multidict_class: type[CIMultiDict[str]], + case_insensitive_str_class: type[istr], + ) -> None: + d = case_insensitive_multidict_class( + [ + ("KEY", "one"), + ] + ) + d["k2"] = "2" + d.extend(k3="3") + + for k in d: + assert type(k) is case_insensitive_str_class + + for k in d.keys(): + assert type(k) is case_insensitive_str_class + + for k, v in d.items(): + assert type(k) is case_insensitive_str_class + + k, v = d.popitem() + assert type(k) is case_insensitive_str_class