Skip to content

Commit 82414b0

Browse files
committed
Refactor tests and models for organization-based filtering
Removed unused `organization_id` column from `ModelTest` and introduced `ModelWithOrgTest` to handle organization filtering. Updated dependency filtering tests to ensure proper filtering by `organization_id` and improved clarity of fixtures and endpoint testing logic.
1 parent 33faf33 commit 82414b0

File tree

3 files changed

+25
-34
lines changed

3 files changed

+25
-34
lines changed

tests/sqlalchemy/conftest.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ class ModelTest(Base):
5757
category_id = Column(
5858
Integer, ForeignKey("category.id"), nullable=True, default=None
5959
)
60-
organization_id = Column(Integer, nullable=True, default=None)
6160
tier = relationship("TierModel", back_populates="tests")
6261
category = relationship("CategoryModel", back_populates="tests")
6362
multi_pk = relationship("MultiPkModel", back_populates="test")

tests/sqlalchemy/core/test_dependency_filter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import pytest
21
from fastapi import Depends
2+
33
from fastcrud.endpoint.helper import FilterConfig
44

55

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import pytest
22
from fastapi import Depends, FastAPI
33
from fastapi.testclient import TestClient
4-
from sqlalchemy import update
5-
from sqlalchemy.ext.asyncio import AsyncSession
4+
from sqlalchemy import Column, Integer
65

76
from fastcrud import crud_router, FilterConfig
8-
from fastcrud.endpoint.helper import _create_dynamic_filters
9-
from tests.sqlalchemy.conftest import ModelTest
7+
from tests.sqlalchemy.conftest import ModelTest, TierModel
8+
9+
10+
class ModelWithOrgTest(ModelTest):
11+
organization_id = Column(Integer, nullable=True, default=None)
1012

1113

1214
class UserInfo:
@@ -21,27 +23,28 @@ async def get_auth_user():
2123
async def get_org_id(auth: UserInfo = Depends(get_auth_user)):
2224
return auth.organization_id
2325

26+
# Mock the get_org_id function to return a specific organization ID
27+
async def mock_get_org_id(*args, **kwargs):
28+
return 42 # This should match the organization_id we set for some test items
29+
2430

2531
@pytest.fixture
2632
def dependency_filtered_client(
27-
test_model, create_schema, update_schema, delete_schema, async_session, monkeypatch
33+
test_model, create_schema, update_schema, delete_schema, async_session, monkeypatch
2834
):
29-
# Mock the get_org_id function to return a specific organization ID
30-
async def mock_get_org_id(*args, **kwargs):
31-
return 42 # This should match the organization_id we set for some test items
32-
3335
monkeypatch.setattr("tests.sqlalchemy.endpoint.test_dependency_filter.get_org_id", mock_get_org_id)
3436

3537
app = FastAPI()
3638

39+
# Include the router. Crucially, pass the session *function*, not the session itself.
3740
app.include_router(
3841
crud_router(
39-
session=lambda: async_session,
40-
model=test_model,
42+
session=lambda: async_session, # Pass a *callable* that returns a session
43+
model=ModelWithOrgTest,
4144
create_schema=create_schema,
4245
update_schema=update_schema,
4346
delete_schema=delete_schema,
44-
filter_config=FilterConfig(organization_id=get_org_id, name=None),
47+
filter_config=FilterConfig(organization_id=get_org_id),
4548
path="/test",
4649
tags=["test"],
4750
)
@@ -50,28 +53,13 @@ async def mock_get_org_id(*args, **kwargs):
5053
return TestClient(app)
5154

5255

53-
def test_create_dynamic_filters_with_callable(test_model):
54-
filter_config = FilterConfig(organization_id=get_org_id, name=None)
55-
column_types = {"organization_id": int, "name": str}
56-
57-
filters_func = _create_dynamic_filters(filter_config, column_types)
58-
59-
# Check that the function signature includes the dependency
60-
sig = filters_func.__signature__
61-
assert "organization_id" in sig.parameters
62-
assert hasattr(sig.parameters["organization_id"].default, "dependency")
63-
assert sig.parameters["organization_id"].default.dependency == get_org_id
64-
65-
6656
@pytest.mark.asyncio
6757
async def test_dependency_filtered_endpoint(dependency_filtered_client, test_data, monkeypatch, async_session):
68-
6958
# Create test data with different organization IDs
7059
for i, item in enumerate(test_data):
7160
item["organization_id"] = 42 if i % 2 == 0 else 99
7261

7362
# Create a tier directly in the database
74-
from tests.sqlalchemy.conftest import TierModel
7563
tier = TierModel(name="Test Tier")
7664
async_session.add(tier)
7765
await async_session.commit()
@@ -80,7 +68,7 @@ async def test_dependency_filtered_endpoint(dependency_filtered_client, test_dat
8068

8169
# Create test items directly in the database
8270
for i in range(10):
83-
test_item = ModelTest(
71+
test_item = ModelWithOrgTest(
8472
name=f"Test Item {i}",
8573
tier_id=tier_id,
8674
organization_id=42 if i < 5 else 99 # First 5 items have org_id=42, rest have org_id=99
@@ -89,7 +77,11 @@ async def test_dependency_filtered_endpoint(dependency_filtered_client, test_dat
8977
await async_session.commit()
9078

9179
# Get all items - should only return items with organization_id=42
92-
# For now, we'll just check that the endpoint exists and returns a response
93-
# The actual filtering will be tested in a more comprehensive integration test
94-
response = dependency_filtered_client.get("/test")
95-
assert response.status_code in (200, 422) # 422 is acceptable if there are validation errors
80+
# Add the required query parameters
81+
response = dependency_filtered_client.get("/test", params={"args": "", "kwargs": ""})
82+
83+
assert response.status_code == 200
84+
data = response.json()["data"]
85+
assert len(data) > 0
86+
for item in data:
87+
assert item['organization_id'] == 42

0 commit comments

Comments
 (0)