Skip to content

Commit aa2f5d2

Browse files
committed
Feat: new query param to make sorting available.
1 parent 9724fb9 commit aa2f5d2

File tree

3 files changed

+296
-0
lines changed

3 files changed

+296
-0
lines changed

fastcrud/endpoint/endpoint_creator.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,9 @@ async def endpoint(
398398
items_per_page: Optional[int] = Query(
399399
None, alias="itemsPerPage", description="Number of items per page"
400400
),
401+
sort: Optional[str] = Query(
402+
None, description="Sort field(s) in format: 'field1,-field2' (prefix with - for descending)"
403+
),
401404
filters: dict = Depends(dynamic_filters),
402405
) -> Union[dict[str, Any], PaginatedListResponse, ListResponse]:
403406
is_paginated = (page is not None) or (items_per_page is not None)
@@ -419,20 +422,39 @@ async def endpoint(
419422
offset = 0
420423
limit = 100
421424

425+
# Parse sort parameter
426+
sort_columns = None
427+
sort_orders = None
428+
if sort:
429+
sort_fields = sort.split(',')
430+
sort_columns = []
431+
sort_orders = []
432+
for field in sort_fields:
433+
if field.startswith('-'):
434+
sort_columns.append(field[1:])
435+
sort_orders.append('desc')
436+
else:
437+
sort_columns.append(field)
438+
sort_orders.append('asc')
439+
422440
if self.select_schema is not None:
423441
crud_data = await self.crud.get_multi(
424442
db,
425443
offset=offset, # type: ignore
426444
limit=limit, # type: ignore
427445
schema_to_select=self.select_schema,
428446
return_as_model=True,
447+
sort_columns=sort_columns,
448+
sort_orders=sort_orders,
429449
**filters,
430450
)
431451
else:
432452
crud_data = await self.crud.get_multi(
433453
db,
434454
offset=offset, # type: ignore
435455
limit=limit, # type: ignore
456+
sort_columns=sort_columns,
457+
sort_orders=sort_orders,
436458
**filters,
437459
)
438460

@@ -652,6 +674,8 @@ def get_current_user(...):
652674
f"Read multiple {self.model.__name__} rows from the database.\n\n"
653675
f"- Use page & itemsPerPage for paginated results\n"
654676
f"- Use offset & limit for specific ranges\n"
677+
f"- Use sort parameter for sorting results (e.g., 'name' for ascending, '-name' for descending)\n"
678+
f"- Multiple sort fields can be specified with comma separation (e.g., 'name,-age')\n"
655679
f"- Returns paginated response when using page/itemsPerPage\n"
656680
f"- Returns simple list response when using offset/limit"
657681
),
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import pytest
2+
from fastapi import FastAPI
3+
from fastapi.testclient import TestClient
4+
5+
from fastcrud import EndpointCreator
6+
from fastcrud.endpoint.helper import FilterConfig
7+
from tests.sqlalchemy.conftest import ModelTest, CreateSchemaTest, UpdateSchemaTest
8+
9+
10+
@pytest.fixture
11+
def app(async_session):
12+
app = FastAPI()
13+
endpoint_creator = EndpointCreator(
14+
session=lambda: async_session,
15+
model=ModelTest,
16+
create_schema=CreateSchemaTest,
17+
update_schema=UpdateSchemaTest,
18+
path="/test",
19+
filter_config=FilterConfig(id=None, name=None, tier_id=None),
20+
)
21+
endpoint_creator.add_routes_to_router()
22+
app.include_router(endpoint_creator.router)
23+
return app
24+
25+
26+
@pytest.fixture
27+
def client(app):
28+
return TestClient(app)
29+
30+
31+
@pytest.mark.asyncio
32+
async def test_get_multi_with_sort_ascending(client, async_session, test_data):
33+
# Add test data
34+
for item in test_data:
35+
async_session.add(ModelTest(**item))
36+
await async_session.commit()
37+
38+
# Test ascending sort by name
39+
response = client.get("/test/?sort=name")
40+
print("Response:", response.status_code, response.json())
41+
assert response.status_code == 200
42+
43+
data = response.json()["data"]
44+
sorted_data = sorted(test_data, key=lambda x: x["name"])
45+
46+
assert len(data) == len(sorted_data)
47+
assert [item["name"] for item in data] == [item["name"] for item in sorted_data]
48+
49+
50+
@pytest.mark.asyncio
51+
async def test_get_multi_with_sort_descending(client, async_session, test_data):
52+
# Add test data
53+
for item in test_data:
54+
async_session.add(ModelTest(**item))
55+
await async_session.commit()
56+
57+
# Test descending sort by name
58+
response = client.get("/test/?sort=-name")
59+
assert response.status_code == 200
60+
61+
data = response.json()["data"]
62+
sorted_data = sorted(test_data, key=lambda x: x["name"], reverse=True)
63+
64+
assert len(data) == len(sorted_data)
65+
assert [item["name"] for item in data] == [item["name"] for item in sorted_data]
66+
67+
68+
@pytest.mark.asyncio
69+
async def test_get_multi_with_multiple_sort_fields(client, async_session, test_data):
70+
# Add test data
71+
for item in test_data:
72+
async_session.add(ModelTest(**item))
73+
await async_session.commit()
74+
75+
# Test multiple sort fields (tier_id ascending, name descending)
76+
response = client.get("/test/?sort=tier_id,-name")
77+
assert response.status_code == 200
78+
79+
data = response.json()["data"]
80+
81+
# Sort first by tier_id (ascending) then by name (descending)
82+
sorted_data = sorted(test_data, key=lambda x: (x["tier_id"], -ord(x["name"][0])))
83+
84+
assert len(data) == len(sorted_data)
85+
86+
# Group by tier_id and check that names are in descending order within each group
87+
tier_groups = {}
88+
for item in data:
89+
tier_id = item["tier_id"]
90+
if tier_id not in tier_groups:
91+
tier_groups[tier_id] = []
92+
tier_groups[tier_id].append(item["name"])
93+
94+
for tier_id, names in tier_groups.items():
95+
if len(names) > 1:
96+
for i in range(len(names) - 1):
97+
assert names[i] >= names[i + 1], f"Names in tier {tier_id} are not in descending order"
98+
99+
100+
@pytest.mark.asyncio
101+
async def test_get_multi_with_sort_and_pagination(client, async_session, test_data):
102+
# Add test data
103+
for item in test_data:
104+
async_session.add(ModelTest(**item))
105+
await async_session.commit()
106+
107+
# Test sorting with pagination
108+
response = client.get("/test/?sort=name&page=1&itemsPerPage=5")
109+
assert response.status_code == 200
110+
111+
data = response.json()["data"]
112+
sorted_data = sorted(test_data, key=lambda x: x["name"])[:5]
113+
114+
assert len(data) <= 5
115+
assert [item["name"] for item in data] == [item["name"] for item in sorted_data]
116+
117+
118+
@pytest.mark.asyncio
119+
async def test_get_multi_with_sort_and_filtering(client, async_session, test_data):
120+
# Add test data
121+
for item in test_data:
122+
async_session.add(ModelTest(**item))
123+
await async_session.commit()
124+
125+
# Test sorting with filtering
126+
tier_id_to_filter = 1
127+
response = client.get(f"/test/?sort=name&tier_id={tier_id_to_filter}")
128+
print("Response:", response.status_code, response.json())
129+
assert response.status_code == 200
130+
131+
data = response.json()["data"]
132+
filtered_data = [item for item in test_data if item["tier_id"] == tier_id_to_filter]
133+
sorted_filtered_data = sorted(filtered_data, key=lambda x: x["name"])
134+
135+
assert len(data) == len(sorted_filtered_data)
136+
assert all(item["tier_id"] == tier_id_to_filter for item in data)
137+
assert [item["name"] for item in data] == [item["name"] for item in sorted_filtered_data]
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import pytest
2+
from fastapi import FastAPI
3+
from fastapi.testclient import TestClient
4+
5+
from fastcrud import EndpointCreator
6+
from fastcrud.endpoint.helper import FilterConfig
7+
from tests.sqlmodel.conftest import ModelTest, CreateSchemaTest, UpdateSchemaTest
8+
9+
10+
@pytest.fixture
11+
def app(async_session):
12+
app = FastAPI()
13+
endpoint_creator = EndpointCreator(
14+
session=lambda: async_session,
15+
model=ModelTest,
16+
create_schema=CreateSchemaTest,
17+
update_schema=UpdateSchemaTest,
18+
path="/test",
19+
filter_config=FilterConfig(id=None, name=None, tier_id=None),
20+
)
21+
endpoint_creator.add_routes_to_router()
22+
app.include_router(endpoint_creator.router)
23+
return app
24+
25+
26+
@pytest.fixture
27+
def client(app):
28+
return TestClient(app)
29+
30+
31+
@pytest.mark.asyncio
32+
async def test_get_multi_with_sort_ascending(client, async_session, test_data):
33+
# Add test data
34+
for item in test_data:
35+
async_session.add(ModelTest(**item))
36+
await async_session.commit()
37+
38+
# Test ascending sort by name
39+
response = client.get("/test/?sort=name")
40+
assert response.status_code == 200
41+
42+
data = response.json()["data"]
43+
sorted_data = sorted(test_data, key=lambda x: x["name"])
44+
45+
assert len(data) == len(sorted_data)
46+
assert [item["name"] for item in data] == [item["name"] for item in sorted_data]
47+
48+
49+
@pytest.mark.asyncio
50+
async def test_get_multi_with_sort_descending(client, async_session, test_data):
51+
# Add test data
52+
for item in test_data:
53+
async_session.add(ModelTest(**item))
54+
await async_session.commit()
55+
56+
# Test descending sort by name
57+
response = client.get("/test/?sort=-name")
58+
assert response.status_code == 200
59+
60+
data = response.json()["data"]
61+
sorted_data = sorted(test_data, key=lambda x: x["name"], reverse=True)
62+
63+
assert len(data) == len(sorted_data)
64+
assert [item["name"] for item in data] == [item["name"] for item in sorted_data]
65+
66+
67+
@pytest.mark.asyncio
68+
async def test_get_multi_with_multiple_sort_fields(client, async_session, test_data):
69+
# Add test data
70+
for item in test_data:
71+
async_session.add(ModelTest(**item))
72+
await async_session.commit()
73+
74+
# Test multiple sort fields (tier_id ascending, name descending)
75+
response = client.get("/test/?sort=tier_id,-name")
76+
assert response.status_code == 200
77+
78+
data = response.json()["data"]
79+
80+
# Sort first by tier_id (ascending) then by name (descending)
81+
sorted_data = sorted(test_data, key=lambda x: (x["tier_id"], -ord(x["name"][0])))
82+
83+
assert len(data) == len(sorted_data)
84+
85+
# Group by tier_id and check that names are in descending order within each group
86+
tier_groups = {}
87+
for item in data:
88+
tier_id = item["tier_id"]
89+
if tier_id not in tier_groups:
90+
tier_groups[tier_id] = []
91+
tier_groups[tier_id].append(item["name"])
92+
93+
for tier_id, names in tier_groups.items():
94+
if len(names) > 1:
95+
for i in range(len(names) - 1):
96+
assert names[i] >= names[i + 1], f"Names in tier {tier_id} are not in descending order"
97+
98+
99+
@pytest.mark.asyncio
100+
async def test_get_multi_with_sort_and_pagination(client, async_session, test_data):
101+
# Add test data
102+
for item in test_data:
103+
async_session.add(ModelTest(**item))
104+
await async_session.commit()
105+
106+
# Test sorting with pagination
107+
response = client.get("/test/?sort=name&page=1&itemsPerPage=5")
108+
assert response.status_code == 200
109+
110+
data = response.json()["data"]
111+
sorted_data = sorted(test_data, key=lambda x: x["name"])[:5]
112+
113+
assert len(data) <= 5
114+
assert [item["name"] for item in data] == [item["name"] for item in sorted_data]
115+
116+
117+
@pytest.mark.asyncio
118+
async def test_get_multi_with_sort_and_filtering(client, async_session, test_data):
119+
# Add test data
120+
for item in test_data:
121+
async_session.add(ModelTest(**item))
122+
await async_session.commit()
123+
124+
# Test sorting with filtering
125+
tier_id_to_filter = 1
126+
response = client.get(f"/test/?sort=name")
127+
assert response.status_code == 200
128+
129+
data = response.json()["data"]
130+
filtered_response = [item for item in data if item["tier_id"] == tier_id_to_filter]
131+
filtered_data = [item for item in test_data if item["tier_id"] == tier_id_to_filter]
132+
sorted_filtered_data = sorted(filtered_data, key=lambda x: x["name"])
133+
assert len(filtered_response) == len(sorted_filtered_data)
134+
assert all(item["tier_id"] == tier_id_to_filter for item in filtered_response)
135+
assert [item["name"] for item in filtered_response] == [item["name"] for item in sorted_filtered_data]

0 commit comments

Comments
 (0)