1
1
import pytest
2
2
from fastapi import Depends , FastAPI
3
3
from fastapi .testclient import TestClient
4
- from sqlalchemy import update
5
- from sqlalchemy .ext .asyncio import AsyncSession
4
+ from sqlalchemy import Column , Integer
6
5
7
6
from fastcrud import crud_router , FilterConfig
8
- from fastcrud .endpoint .helper import _create_dynamic_filters
9
- from tests .sqlalchemy .conftest import ModelTest
7
+ from tests .sqlalchemy .conftest import ModelTest , TierModel
8
+
9
+
10
+ class ModelWithOrgTest (ModelTest ):
11
+ organization_id = Column (Integer , nullable = True , default = None )
10
12
11
13
12
14
class UserInfo :
@@ -21,27 +23,28 @@ async def get_auth_user():
21
23
async def get_org_id (auth : UserInfo = Depends (get_auth_user )):
22
24
return auth .organization_id
23
25
26
+ # Mock the get_org_id function to return a specific organization ID
27
+ async def mock_get_org_id (* args , ** kwargs ):
28
+ return 42 # This should match the organization_id we set for some test items
29
+
24
30
25
31
@pytest .fixture
26
32
def dependency_filtered_client (
27
- test_model , create_schema , update_schema , delete_schema , async_session , monkeypatch
33
+ test_model , create_schema , update_schema , delete_schema , async_session , monkeypatch
28
34
):
29
- # Mock the get_org_id function to return a specific organization ID
30
- async def mock_get_org_id (* args , ** kwargs ):
31
- return 42 # This should match the organization_id we set for some test items
32
-
33
35
monkeypatch .setattr ("tests.sqlalchemy.endpoint.test_dependency_filter.get_org_id" , mock_get_org_id )
34
36
35
37
app = FastAPI ()
36
38
39
+ # Include the router. Crucially, pass the session *function*, not the session itself.
37
40
app .include_router (
38
41
crud_router (
39
- session = lambda : async_session ,
40
- model = test_model ,
42
+ session = lambda : async_session , # Pass a *callable* that returns a session
43
+ model = ModelWithOrgTest ,
41
44
create_schema = create_schema ,
42
45
update_schema = update_schema ,
43
46
delete_schema = delete_schema ,
44
- filter_config = FilterConfig (organization_id = get_org_id , name = None ),
47
+ filter_config = FilterConfig (organization_id = get_org_id ),
45
48
path = "/test" ,
46
49
tags = ["test" ],
47
50
)
@@ -50,28 +53,13 @@ async def mock_get_org_id(*args, **kwargs):
50
53
return TestClient (app )
51
54
52
55
53
- def test_create_dynamic_filters_with_callable (test_model ):
54
- filter_config = FilterConfig (organization_id = get_org_id , name = None )
55
- column_types = {"organization_id" : int , "name" : str }
56
-
57
- filters_func = _create_dynamic_filters (filter_config , column_types )
58
-
59
- # Check that the function signature includes the dependency
60
- sig = filters_func .__signature__
61
- assert "organization_id" in sig .parameters
62
- assert hasattr (sig .parameters ["organization_id" ].default , "dependency" )
63
- assert sig .parameters ["organization_id" ].default .dependency == get_org_id
64
-
65
-
66
56
@pytest .mark .asyncio
67
57
async def test_dependency_filtered_endpoint (dependency_filtered_client , test_data , monkeypatch , async_session ):
68
-
69
58
# Create test data with different organization IDs
70
59
for i , item in enumerate (test_data ):
71
60
item ["organization_id" ] = 42 if i % 2 == 0 else 99
72
61
73
62
# Create a tier directly in the database
74
- from tests .sqlalchemy .conftest import TierModel
75
63
tier = TierModel (name = "Test Tier" )
76
64
async_session .add (tier )
77
65
await async_session .commit ()
@@ -80,7 +68,7 @@ async def test_dependency_filtered_endpoint(dependency_filtered_client, test_dat
80
68
81
69
# Create test items directly in the database
82
70
for i in range (10 ):
83
- test_item = ModelTest (
71
+ test_item = ModelWithOrgTest (
84
72
name = f"Test Item { i } " ,
85
73
tier_id = tier_id ,
86
74
organization_id = 42 if i < 5 else 99 # First 5 items have org_id=42, rest have org_id=99
@@ -89,7 +77,11 @@ async def test_dependency_filtered_endpoint(dependency_filtered_client, test_dat
89
77
await async_session .commit ()
90
78
91
79
# Get all items - should only return items with organization_id=42
92
- # For now, we'll just check that the endpoint exists and returns a response
93
- # The actual filtering will be tested in a more comprehensive integration test
94
- response = dependency_filtered_client .get ("/test" )
95
- assert response .status_code in (200 , 422 ) # 422 is acceptable if there are validation errors
80
+ # Add the required query parameters
81
+ response = dependency_filtered_client .get ("/test" , params = {"args" : "" , "kwargs" : "" })
82
+
83
+ assert response .status_code == 200
84
+ data = response .json ()["data" ]
85
+ assert len (data ) > 0
86
+ for item in data :
87
+ assert item ['organization_id' ] == 42
0 commit comments