Skip to content

Implement advanced filter configs #204

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion fastcrud/endpoint/endpoint_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,18 @@ def __init__(

def _validate_filter_config(self, filter_config: FilterConfig) -> None:
model_columns = self.crud.model_col_names
supported_filters = self.crud._SUPPORTED_FILTERS
for key in filter_config.filters.keys():
if key not in model_columns:
if "__" in key:
field_name, op = key.rsplit("__", 1)
if op not in supported_filters:
raise ValueError(
f"Invalid filter op '{op}': following filter ops are allowed: {supported_filters.keys()}"
)
else:
field_name = key

if field_name not in model_columns:
raise ValueError(
f"Invalid filter column '{key}': not found in model '{self.model.__name__}' columns"
)
Expand Down
2 changes: 1 addition & 1 deletion tests/sqlalchemy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def filtered_client(
create_schema=create_schema,
update_schema=update_schema,
delete_schema=delete_schema,
filter_config=FilterConfig(tier_id=None, name=None),
filter_config=FilterConfig(tier_id=None, name=None, name__startswith=None),
path="/test",
tags=["test"],
endpoint_names={
Expand Down
26 changes: 26 additions & 0 deletions tests/sqlalchemy/crud/test_get_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,29 @@ async def test_get_multi_handle_validation_error(async_session, test_model):
assert "Data validation error for schema CustomCreateSchemaTest:" in str(
exc_info.value
)


@pytest.mark.asyncio
async def test_read_items_with_advanced_filters(
async_session, test_model, test_data
):
for data in test_data:
new_item = test_model(**data)
async_session.add(new_item)
await async_session.commit()

crud = FastCRUD(test_model)

# Test startswith filter
name = "Ali"
result = await crud.get_multi(async_session, name__startswith=name)

assert len(result["data"]) > 0
for item in result["data"]:
assert item["name"].startswith(name)

# Test with non-matching filter
name = "Nothing"
result = await crud.get_multi(async_session, name__startswith=name)

assert len(result["data"]) == 0
2 changes: 1 addition & 1 deletion tests/sqlmodel/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ def filtered_client(
create_schema=create_schema,
update_schema=update_schema,
delete_schema=delete_schema,
filter_config=FilterConfig(tier_id=None, name=None),
filter_config=FilterConfig(tier_id=None, name=None, name__startswith=None),
path="/test",
tags=["test"],
endpoint_names={
Expand Down
31 changes: 31 additions & 0 deletions tests/sqlmodel/endpoint/test_get_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,34 @@ async def test_read_items_with_schema(
assert len(data["data"]) > 0

assert all(read_schema.model_validate(item) for item in data["data"])


@pytest.mark.asyncio
async def test_read_items_with_advanced_filters(
filtered_client: TestClient, async_session, test_model, test_data
):
for data in test_data:
new_item = test_model(**data)
async_session.add(new_item)
await async_session.commit()

name = "Ali"
response = filtered_client.get(f"/test/get_multi?name__startswith={name}")

assert response.status_code == 200
data = response.json()

assert "data" in data
assert len(data["data"]) > 0

for item in data["data"]:
assert item["name"].startswith(name)

name = "Nothing"
response = filtered_client.get(f"/test/get_multi?name__startswith={name}")

assert response.status_code == 200
data = response.json()

assert "data" in data
assert len(data["data"]) == 0