43
43
from sqlalchemy .exc import CompileError
44
44
from sqlalchemy .sql .elements import TextClause
45
45
from sqlalchemy .sql .type_api import UserDefinedType
46
+ from sqlalchemy .types import TypeEngine
46
47
47
48
from .models import (
48
49
ColumnAttribute ,
63
64
uses_default_name ,
64
65
)
65
66
66
- if sys .version_info < (3 , 10 ):
67
- pass
68
- else :
69
- pass
70
-
71
67
_re_boolean_check_constraint = re .compile (r"(?:.*?\.)?(.*?) IN \(0, 1\)" )
72
68
_re_column_name = re .compile (r'(?:(["`]?).*\1\.)?(["`]?)(.*)\2' )
73
69
_re_enum_check_constraint = re .compile (r"(?:.*?\.)?(.*?) IN \((.+)\)" )
@@ -1201,22 +1197,40 @@ def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
1201
1197
column = column_attr .column
1202
1198
rendered_column = self .render_column (column , column_attr .name != column .name )
1203
1199
1204
- try :
1205
- python_type = column .type .python_type
1200
+ def get_type_qualifiers () -> tuple [str , TypeEngine [Any ], str ]:
1201
+ column_type = column .type
1202
+ pre : list [str ] = []
1203
+ post_size = 0
1204
+ if column .nullable :
1205
+ self .add_literal_import ("typing" , "Optional" )
1206
+ pre .append ("Optional[" )
1207
+ post_size += 1
1208
+
1209
+ if isinstance (column_type , ARRAY ):
1210
+ dim = getattr (column_type , "dimensions" , None ) or 1
1211
+ pre .extend ("list[" for _ in range (dim ))
1212
+ post_size += dim
1213
+
1214
+ column_type = column_type .item_type
1215
+
1216
+ return "" .join (pre ), column_type , "]" * post_size
1217
+
1218
+ def render_python_type (column_type : TypeEngine [Any ]) -> str :
1219
+ python_type = column_type .python_type
1206
1220
python_type_name = python_type .__name__
1207
- if python_type .__module__ == "builtins" :
1208
- column_python_type = python_type_name
1209
- else :
1210
- python_type_module = python_type . __module__
1211
- column_python_type = f" { python_type_module } . { python_type_name } "
1221
+ python_type_module = python_type .__module__
1222
+ if python_type_module == "builtins" :
1223
+ return python_type_name
1224
+
1225
+ try :
1212
1226
self .add_module_import (python_type_module )
1213
- except NotImplementedError :
1214
- self .add_literal_import ("typing" , "Any" )
1215
- column_python_type = "Any"
1227
+ return f"{ python_type_module } .{ python_type_name } "
1228
+ except NotImplementedError :
1229
+ self .add_literal_import ("typing" , "Any" )
1230
+ return "Any"
1216
1231
1217
- if column .nullable :
1218
- self .add_literal_import ("typing" , "Optional" )
1219
- column_python_type = f"Optional[{ column_python_type } ]"
1232
+ pre , col_type , post = get_type_qualifiers ()
1233
+ column_python_type = f"{ pre } { render_python_type (col_type )} { post } "
1220
1234
return f"{ column_attr .name } : Mapped[{ column_python_type } ] = { rendered_column } "
1221
1235
1222
1236
def render_relationship (self , relationship : RelationshipAttribute ) -> str :
@@ -1297,8 +1311,7 @@ def render_join(terms: list[JoinType]) -> str:
1297
1311
1298
1312
relationship_type : str
1299
1313
if relationship .type == RelationshipType .ONE_TO_MANY :
1300
- self .add_literal_import ("typing" , "List" )
1301
- relationship_type = f"List['{ relationship .target .name } ']"
1314
+ relationship_type = f"list['{ relationship .target .name } ']"
1302
1315
elif relationship .type in (
1303
1316
RelationshipType .ONE_TO_ONE ,
1304
1317
RelationshipType .MANY_TO_ONE ,
@@ -1310,8 +1323,7 @@ def render_join(terms: list[JoinType]) -> str:
1310
1323
self .add_literal_import ("typing" , "Optional" )
1311
1324
relationship_type = f"Optional[{ relationship_type } ]"
1312
1325
elif relationship .type == RelationshipType .MANY_TO_MANY :
1313
- self .add_literal_import ("typing" , "List" )
1314
- relationship_type = f"List['{ relationship .target .name } ']"
1326
+ relationship_type = f"list['{ relationship .target .name } ']"
1315
1327
else :
1316
1328
self .add_literal_import ("typing" , "Any" )
1317
1329
relationship_type = "Any"
@@ -1409,13 +1421,6 @@ def collect_imports_for_model(self, model: Model) -> None:
1409
1421
if model .relationships :
1410
1422
self .add_literal_import ("sqlmodel" , "Relationship" )
1411
1423
1412
- for relationship_attr in model .relationships :
1413
- if relationship_attr .type in (
1414
- RelationshipType .ONE_TO_MANY ,
1415
- RelationshipType .MANY_TO_MANY ,
1416
- ):
1417
- self .add_literal_import ("typing" , "List" )
1418
-
1419
1424
def collect_imports_for_column (self , column : Column [Any ]) -> None :
1420
1425
super ().collect_imports_for_column (column )
1421
1426
try :
@@ -1487,8 +1492,7 @@ def render_relationship(self, relationship: RelationshipAttribute) -> str:
1487
1492
RelationshipType .ONE_TO_MANY ,
1488
1493
RelationshipType .MANY_TO_MANY ,
1489
1494
):
1490
- self .add_literal_import ("typing" , "List" )
1491
- annotation = f"List[{ annotation } ]"
1495
+ annotation = f"list[{ annotation } ]"
1492
1496
else :
1493
1497
self .add_literal_import ("typing" , "Optional" )
1494
1498
annotation = f"Optional[{ annotation } ]"
0 commit comments