Skip to content

Commit b1cec51

Browse files
authored
Ensure CursorPagination respects nulls in the ordering field (#8912)
* Ensure CursorPagination respects nulls in the ordering field * Lint * Fix pagination tests * Add test_ascending with nulls * Push tests for nulls * Test pass * Add comment * Fix test for django30
1 parent 62abf6a commit b1cec51

File tree

2 files changed

+142
-8
lines changed

2 files changed

+142
-8
lines changed

rest_framework/pagination.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from django.core.paginator import InvalidPage
1212
from django.core.paginator import Paginator as DjangoPaginator
13+
from django.db.models import Q
1314
from django.template import loader
1415
from django.utils.encoding import force_str
1516
from django.utils.translation import gettext_lazy as _
@@ -620,7 +621,7 @@ def paginate_queryset(self, queryset, request, view=None):
620621
queryset = queryset.order_by(*self.ordering)
621622

622623
# If we have a cursor with a fixed position then filter by that.
623-
if current_position is not None:
624+
if str(current_position) != 'None':
624625
order = self.ordering[0]
625626
is_reversed = order.startswith('-')
626627
order_attr = order.lstrip('-')
@@ -631,7 +632,12 @@ def paginate_queryset(self, queryset, request, view=None):
631632
else:
632633
kwargs = {order_attr + '__gt': current_position}
633634

634-
queryset = queryset.filter(**kwargs)
635+
filter_query = Q(**kwargs)
636+
# If some records contain a null for the ordering field, don't lose them.
637+
# When reverse ordering, nulls will come last and need to be included.
638+
if (reverse and not is_reversed) or is_reversed:
639+
filter_query |= Q(**{order_attr + '__isnull': True})
640+
queryset = queryset.filter(filter_query)
635641

636642
# If we have an offset cursor then offset the entire page by that amount.
637643
# We also always fetch an extra item in order to determine if there is a
@@ -704,7 +710,7 @@ def get_next_link(self):
704710
# The item in this position and the item following it
705711
# have different positions. We can use this position as
706712
# our marker.
707-
has_item_with_unique_position = True
713+
has_item_with_unique_position = position is not None
708714
break
709715

710716
# The item in this position has the same position as the item
@@ -757,7 +763,7 @@ def get_previous_link(self):
757763
# The item in this position and the item following it
758764
# have different positions. We can use this position as
759765
# our marker.
760-
has_item_with_unique_position = True
766+
has_item_with_unique_position = position is not None
761767
break
762768

763769
# The item in this position has the same position as the item
@@ -883,7 +889,7 @@ def _get_position_from_instance(self, instance, ordering):
883889
attr = instance[field_name]
884890
else:
885891
attr = getattr(instance, field_name)
886-
return str(attr)
892+
return None if attr is None else str(attr)
887893

888894
def get_paginated_response(self, data):
889895
return Response(OrderedDict([

tests/test_pagination.py

+131-3
Original file line numberDiff line numberDiff line change
@@ -951,17 +951,24 @@ class MockQuerySet:
951951
def __init__(self, items):
952952
self.items = items
953953

954-
def filter(self, created__gt=None, created__lt=None):
954+
def filter(self, q):
955+
q_args = dict(q.deconstruct()[1])
956+
if not q_args:
957+
# django 3.0.x artifact
958+
q_args = dict(q.deconstruct()[2])
959+
created__gt = q_args.get('created__gt')
960+
created__lt = q_args.get('created__lt')
961+
955962
if created__gt is not None:
956963
return MockQuerySet([
957964
item for item in self.items
958-
if item.created > int(created__gt)
965+
if item.created is None or item.created > int(created__gt)
959966
])
960967

961968
assert created__lt is not None
962969
return MockQuerySet([
963970
item for item in self.items
964-
if item.created < int(created__lt)
971+
if item.created is None or item.created < int(created__lt)
965972
])
966973

967974
def order_by(self, *ordering):
@@ -1080,6 +1087,127 @@ def get_pages(self, url):
10801087
return (previous, current, next, previous_url, next_url)
10811088

10821089

1090+
class NullableCursorPaginationModel(models.Model):
1091+
created = models.IntegerField(null=True)
1092+
1093+
1094+
class TestCursorPaginationWithNulls(TestCase):
1095+
"""
1096+
Unit tests for `pagination.CursorPagination` with ordering on a nullable field.
1097+
"""
1098+
1099+
def setUp(self):
1100+
class ExamplePagination(pagination.CursorPagination):
1101+
page_size = 1
1102+
ordering = 'created'
1103+
1104+
self.pagination = ExamplePagination()
1105+
data = [
1106+
None, None, 3, 4
1107+
]
1108+
for idx in data:
1109+
NullableCursorPaginationModel.objects.create(created=idx)
1110+
1111+
self.queryset = NullableCursorPaginationModel.objects.all()
1112+
1113+
get_pages = TestCursorPagination.get_pages
1114+
1115+
def test_ascending(self):
1116+
"""Test paginating one row at a time, current should go 1, 2, 3, 4, 3, 2, 1."""
1117+
(previous, current, next, previous_url, next_url) = self.get_pages('/')
1118+
1119+
assert previous is None
1120+
assert current == [None]
1121+
assert next == [None]
1122+
1123+
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
1124+
1125+
assert previous == [None]
1126+
assert current == [None]
1127+
assert next == [3]
1128+
1129+
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
1130+
1131+
assert previous == [3] # [None] paging artifact documented at https://github.com/ddelange/django-rest-framework/blob/3.14.0/rest_framework/pagination.py#L789
1132+
assert current == [3]
1133+
assert next == [4]
1134+
1135+
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
1136+
1137+
assert previous == [3]
1138+
assert current == [4]
1139+
assert next is None
1140+
assert next_url is None
1141+
1142+
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
1143+
1144+
assert previous == [None]
1145+
assert current == [3]
1146+
assert next == [4]
1147+
1148+
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
1149+
1150+
assert previous == [None]
1151+
assert current == [None]
1152+
assert next == [None] # [3] paging artifact documented at https://github.com/ddelange/django-rest-framework/blob/3.14.0/rest_framework/pagination.py#L731
1153+
1154+
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
1155+
1156+
assert previous is None
1157+
assert current == [None]
1158+
assert next == [None]
1159+
1160+
def test_descending(self):
1161+
"""Test paginating one row at a time, current should go 4, 3, 2, 1, 2, 3, 4."""
1162+
self.pagination.ordering = ('-created',)
1163+
(previous, current, next, previous_url, next_url) = self.get_pages('/')
1164+
1165+
assert previous is None
1166+
assert current == [4]
1167+
assert next == [3]
1168+
1169+
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
1170+
1171+
assert previous == [None] # [4] paging artifact
1172+
assert current == [3]
1173+
assert next == [None]
1174+
1175+
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
1176+
1177+
assert previous == [None] # [3] paging artifact
1178+
assert current == [None]
1179+
assert next == [None]
1180+
1181+
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
1182+
1183+
assert previous == [None]
1184+
assert current == [None]
1185+
assert next is None
1186+
assert next_url is None
1187+
1188+
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
1189+
1190+
assert previous == [3]
1191+
assert current == [None]
1192+
assert next == [None]
1193+
1194+
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
1195+
1196+
assert previous == [None]
1197+
assert current == [3]
1198+
assert next == [3] # [4] paging artifact documented at https://github.com/ddelange/django-rest-framework/blob/3.14.0/rest_framework/pagination.py#L731
1199+
1200+
# skip back artifact
1201+
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
1202+
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
1203+
1204+
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
1205+
1206+
assert previous is None
1207+
assert current == [4]
1208+
assert next == [3]
1209+
1210+
10831211
def test_get_displayed_page_numbers():
10841212
"""
10851213
Test our contextual page display function.

0 commit comments

Comments
 (0)