Skip to content

Commit 5126244

Browse files
authored
UUID support (#188)
* uuid support fix * new tests for uuid * code coverage
1 parent a2bb6fc commit 5126244

File tree

4 files changed

+462
-8
lines changed

4 files changed

+462
-8
lines changed

fastcrud/endpoint/helper.py

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import inspect
2+
from uuid import UUID
23
from typing import Optional, Union, Annotated, Sequence, Callable, TypeVar, Any
34

45
from pydantic import BaseModel, Field
56
from pydantic.functional_validators import field_validator
6-
from fastapi import Depends, Query, params
7+
from fastapi import Depends, Query, Path, params
78

89
from sqlalchemy import Column, inspect as sa_inspect
10+
from sqlalchemy.dialects.postgresql import UUID as PostgresUUID
11+
from sqlalchemy.types import TypeEngine
912
from sqlalchemy.sql.elements import KeyedColumnElement
1013

1114
from fastcrud.types import ModelType
@@ -87,12 +90,36 @@ def _get_primary_keys(
8790
return primary_key_columns
8891

8992

93+
def _is_uuid_type(column_type: TypeEngine) -> bool: # pragma: no cover
94+
"""
95+
Check if a SQLAlchemy column type represents a UUID.
96+
Handles various SQL dialects and common UUID implementations.
97+
"""
98+
if isinstance(column_type, PostgresUUID):
99+
return True
100+
101+
type_name = getattr(column_type, "__visit_name__", "").lower()
102+
if "uuid" in type_name:
103+
return True
104+
105+
if hasattr(column_type, "impl"):
106+
return _is_uuid_type(column_type.impl)
107+
108+
return False
109+
110+
90111
def _get_python_type(column: Column) -> Optional[type]:
112+
"""Get the Python type for a SQLAlchemy column, with special handling for UUIDs."""
91113
try:
114+
if _is_uuid_type(column.type):
115+
return UUID
116+
92117
direct_type: Optional[type] = column.type.python_type
93118
return direct_type
94119
except NotImplementedError:
95120
if hasattr(column.type, "impl") and hasattr(column.type.impl, "python_type"):
121+
if _is_uuid_type(column.type.impl): # pragma: no cover
122+
return UUID
96123
indirect_type: Optional[type] = column.type.impl.python_type
97124
return indirect_type
98125
else: # pragma: no cover
@@ -110,7 +137,10 @@ def _get_column_types(
110137
raise ValueError("Model inspection failed, resulting in None.")
111138
column_types = {}
112139
for column in inspector_result.mapper.columns:
113-
column_types[column.name] = _get_python_type(column)
140+
column_type = _get_python_type(column)
141+
if hasattr(column.type, "__visit_name__") and column.type.__visit_name__ == "uuid":
142+
column_type = UUID
143+
column_types[column.name] = column_type
114144
return column_types
115145

116146

@@ -154,12 +184,24 @@ def wrapper(endpoint):
154184
for p in signature.parameters.values()
155185
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
156186
]
157-
extra_positional_params = [
158-
inspect.Parameter(
159-
name=k, annotation=v, kind=inspect.Parameter.POSITIONAL_ONLY
160-
)
161-
for k, v in pkeys.items()
162-
]
187+
extra_positional_params = []
188+
for k, v in pkeys.items():
189+
if v == UUID:
190+
extra_positional_params.append(
191+
inspect.Parameter(
192+
name=k,
193+
annotation=Annotated[UUID, Path(...)],
194+
kind=inspect.Parameter.POSITIONAL_ONLY
195+
)
196+
)
197+
else:
198+
extra_positional_params.append(
199+
inspect.Parameter(
200+
name=k,
201+
annotation=v,
202+
kind=inspect.Parameter.POSITIONAL_ONLY
203+
)
204+
)
163205

164206
endpoint.__signature__ = signature.replace(
165207
parameters=extra_positional_params + parameters

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ mypy = "^1.9.0"
4545
ruff = "^0.3.4"
4646
coverage = "^7.4.4"
4747
testcontainers = "^4.7.1"
48+
asyncpg = "^0.30.0"
49+
psycopg2-binary = "^2.9.10"
4850
psycopg = "^3.2.1"
4951
aiomysql = "^0.2.0"
5052
cryptography = "^43.0.1"

tests/sqlalchemy/core/test_uuid.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
import pytest
2+
from uuid import UUID, uuid4
3+
4+
from sqlalchemy import Column, String
5+
from sqlalchemy.dialects.postgresql import UUID as PostgresUUID
6+
from sqlalchemy.types import TypeDecorator
7+
from fastapi import FastAPI
8+
from fastapi.testclient import TestClient
9+
10+
from fastcrud import crud_router, FastCRUD
11+
from pydantic import BaseModel
12+
13+
from ..conftest import Base
14+
15+
16+
class UUIDType(TypeDecorator):
17+
"""Platform-independent UUID type.
18+
Uses PostgreSQL's UUID type, otherwise CHAR(36)
19+
"""
20+
21+
impl = String
22+
cache_ok = True
23+
24+
def __init__(self):
25+
super().__init__(36)
26+
27+
def load_dialect_impl(self, dialect):
28+
if dialect.name == "postgresql":
29+
return dialect.type_descriptor(PostgresUUID(as_uuid=True))
30+
else:
31+
return dialect.type_descriptor(String(36))
32+
33+
def process_bind_param(self, value, dialect):
34+
if value is None: # pragma: no cover
35+
return value
36+
elif dialect.name == "postgresql": # pragma: no cover
37+
return value
38+
else:
39+
return str(value)
40+
41+
def process_result_value(self, value, dialect):
42+
if value is None: # pragma: no cover
43+
return value
44+
if not isinstance(value, UUID):
45+
return UUID(value)
46+
return value # pragma: no cover
47+
48+
49+
class UUIDModel(Base):
50+
__tablename__ = "uuid_test"
51+
id = Column(UUIDType(), primary_key=True, default=uuid4)
52+
name = Column(String(255))
53+
54+
55+
class CustomUUID(TypeDecorator):
56+
"""Custom UUID type for testing."""
57+
58+
impl = String
59+
cache_ok = True
60+
61+
def __init__(self):
62+
super().__init__(36)
63+
self.__visit_name__ = "uuid"
64+
65+
def process_bind_param(self, value, dialect):
66+
if value is None: # pragma: no cover
67+
return value
68+
return str(value)
69+
70+
def process_result_value(self, value, dialect):
71+
if value is None: # pragma: no cover
72+
return value
73+
return UUID(value)
74+
75+
76+
class CustomUUIDModel(Base):
77+
__tablename__ = "custom_uuid_test"
78+
id = Column(CustomUUID(), primary_key=True, default=uuid4)
79+
name = Column(String(255))
80+
81+
82+
class UUIDSchema(BaseModel):
83+
id: UUID
84+
name: str
85+
86+
model_config = {"from_attributes": True}
87+
88+
89+
class CreateUUIDSchema(BaseModel):
90+
name: str
91+
92+
model_config = {"from_attributes": True}
93+
94+
95+
class UpdateUUIDSchema(BaseModel):
96+
name: str
97+
98+
model_config = {"from_attributes": True}
99+
100+
101+
@pytest.fixture
102+
def uuid_client(async_session):
103+
app = FastAPI()
104+
105+
app.include_router(
106+
crud_router(
107+
session=lambda: async_session,
108+
model=UUIDModel,
109+
crud=FastCRUD(UUIDModel),
110+
create_schema=CreateUUIDSchema,
111+
update_schema=UpdateUUIDSchema,
112+
path="/uuid-test",
113+
tags=["uuid-test"],
114+
endpoint_names={
115+
"create": "create",
116+
"read": "get",
117+
"update": "update",
118+
"delete": "delete",
119+
"read_multi": "get_multi",
120+
},
121+
)
122+
)
123+
124+
app.include_router(
125+
crud_router(
126+
session=lambda: async_session,
127+
model=CustomUUIDModel,
128+
crud=FastCRUD(CustomUUIDModel),
129+
create_schema=CreateUUIDSchema,
130+
update_schema=UpdateUUIDSchema,
131+
path="/custom-uuid-test",
132+
tags=["custom-uuid-test"],
133+
endpoint_names={
134+
"create": "create",
135+
"read": "get",
136+
"update": "update",
137+
"delete": "delete",
138+
"read_multi": "get_multi",
139+
},
140+
)
141+
)
142+
143+
return TestClient(app)
144+
145+
146+
@pytest.mark.asyncio
147+
@pytest.mark.dialect("sqlite")
148+
async def test_custom_uuid_crud(uuid_client):
149+
response = uuid_client.post("/custom-uuid-test/create", json={"name": "test"})
150+
assert (
151+
response.status_code == 200
152+
), f"Creation failed with response: {response.text}"
153+
154+
try:
155+
data = response.json()
156+
assert "id" in data, f"Response does not contain 'id': {data}"
157+
uuid_id = data["id"]
158+
except Exception as e: # pragma: no cover
159+
pytest.fail(f"Failed to process response: {response.text}. Error: {str(e)}")
160+
161+
try:
162+
UUID(uuid_id)
163+
except ValueError: # pragma: no cover
164+
pytest.fail("Invalid UUID format")
165+
166+
response = uuid_client.get(f"/custom-uuid-test/get/{uuid_id}")
167+
assert response.status_code == 200
168+
assert response.json()["id"] == uuid_id
169+
assert response.json()["name"] == "test"
170+
171+
update_response = uuid_client.patch(
172+
f"/custom-uuid-test/update/{uuid_id}", json={"name": "updated"}
173+
)
174+
response = uuid_client.get(f"/custom-uuid-test/get/{uuid_id}")
175+
176+
assert update_response.status_code == 200
177+
assert response.status_code == 200
178+
assert response.json()["name"] == "updated"
179+
180+
response = uuid_client.delete(f"/custom-uuid-test/delete/{uuid_id}")
181+
assert response.status_code == 200
182+
183+
response = uuid_client.get(f"/custom-uuid-test/get/{uuid_id}")
184+
assert response.status_code == 404
185+
186+
187+
@pytest.mark.asyncio
188+
async def test_uuid_list_endpoint(uuid_client):
189+
created_ids = []
190+
for i in range(3):
191+
response = uuid_client.post("/uuid-test/create", json={"name": f"test_{i}"})
192+
assert response.status_code == 200
193+
created_ids.append(response.json()["id"])
194+
195+
response = uuid_client.get("/uuid-test/get_multi")
196+
assert response.status_code == 200
197+
data = response.json()["data"]
198+
assert len(data) == 3
199+
200+
for item in data:
201+
try:
202+
UUID(item["id"])
203+
except ValueError: # pragma: no cover
204+
pytest.fail("Invalid UUID format in list response")

0 commit comments

Comments
 (0)