Skip to content

Commit b60293f

Browse files
JoaquimEstevespre-commit-ci[bot]agronholm
authored
πŸ› Handles incomplete list-type (#385)
* πŸ› Handles incomplete list-type * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * βœ‚οΈ Removed old-style `typing.List` in favour of just list. Also fixed the pre-commit * 🧹 Simplified getting column-type. No more recursive func * βœ‚οΈ Got rid of the `cast` by using isinstance * βœ‚οΈ Removed .swp files. OOps * 🧹 Added `get_type_qualifiers` function * Added a blank line after a control block. Co-authored-by: Alex GrΓΆnholm <[email protected]> * Update tests/test_generator_declarative.py * 🧹 Renamed function * Formatting/style fixes * Reverted rename * πŸ“ Updated CHANGES.rst * πŸ“ Tweaked CHANGES comment following feedback --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Alex GrΓΆnholm <[email protected]>
1 parent 41d1469 commit b60293f

File tree

5 files changed

+109
-79
lines changed

5 files changed

+109
-79
lines changed

β€ŽCHANGES.rst

+5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Version history
22
===============
33

4+
**UNRELEASED**
5+
6+
- Type annotations for ARRAY column attributes now include the Python type of
7+
the array elements
8+
49
**3.0.0**
510

611
- Dropped support for Python 3.8

β€Žsrc/sqlacodegen/generators.py

+35-31
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from sqlalchemy.exc import CompileError
4444
from sqlalchemy.sql.elements import TextClause
4545
from sqlalchemy.sql.type_api import UserDefinedType
46+
from sqlalchemy.types import TypeEngine
4647

4748
from .models import (
4849
ColumnAttribute,
@@ -63,11 +64,6 @@
6364
uses_default_name,
6465
)
6566

66-
if sys.version_info < (3, 10):
67-
pass
68-
else:
69-
pass
70-
7167
_re_boolean_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \(0, 1\)")
7268
_re_column_name = re.compile(r'(?:(["`]?).*\1\.)?(["`]?)(.*)\2')
7369
_re_enum_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \((.+)\)")
@@ -1201,22 +1197,40 @@ def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
12011197
column = column_attr.column
12021198
rendered_column = self.render_column(column, column_attr.name != column.name)
12031199

1204-
try:
1205-
python_type = column.type.python_type
1200+
def get_type_qualifiers() -> tuple[str, TypeEngine[Any], str]:
1201+
column_type = column.type
1202+
pre: list[str] = []
1203+
post_size = 0
1204+
if column.nullable:
1205+
self.add_literal_import("typing", "Optional")
1206+
pre.append("Optional[")
1207+
post_size += 1
1208+
1209+
if isinstance(column_type, ARRAY):
1210+
dim = getattr(column_type, "dimensions", None) or 1
1211+
pre.extend("list[" for _ in range(dim))
1212+
post_size += dim
1213+
1214+
column_type = column_type.item_type
1215+
1216+
return "".join(pre), column_type, "]" * post_size
1217+
1218+
def render_python_type(column_type: TypeEngine[Any]) -> str:
1219+
python_type = column_type.python_type
12061220
python_type_name = python_type.__name__
1207-
if python_type.__module__ == "builtins":
1208-
column_python_type = python_type_name
1209-
else:
1210-
python_type_module = python_type.__module__
1211-
column_python_type = f"{python_type_module}.{python_type_name}"
1221+
python_type_module = python_type.__module__
1222+
if python_type_module == "builtins":
1223+
return python_type_name
1224+
1225+
try:
12121226
self.add_module_import(python_type_module)
1213-
except NotImplementedError:
1214-
self.add_literal_import("typing", "Any")
1215-
column_python_type = "Any"
1227+
return f"{python_type_module}.{python_type_name}"
1228+
except NotImplementedError:
1229+
self.add_literal_import("typing", "Any")
1230+
return "Any"
12161231

1217-
if column.nullable:
1218-
self.add_literal_import("typing", "Optional")
1219-
column_python_type = f"Optional[{column_python_type}]"
1232+
pre, col_type, post = get_type_qualifiers()
1233+
column_python_type = f"{pre}{render_python_type(col_type)}{post}"
12201234
return f"{column_attr.name}: Mapped[{column_python_type}] = {rendered_column}"
12211235

12221236
def render_relationship(self, relationship: RelationshipAttribute) -> str:
@@ -1297,8 +1311,7 @@ def render_join(terms: list[JoinType]) -> str:
12971311

12981312
relationship_type: str
12991313
if relationship.type == RelationshipType.ONE_TO_MANY:
1300-
self.add_literal_import("typing", "List")
1301-
relationship_type = f"List['{relationship.target.name}']"
1314+
relationship_type = f"list['{relationship.target.name}']"
13021315
elif relationship.type in (
13031316
RelationshipType.ONE_TO_ONE,
13041317
RelationshipType.MANY_TO_ONE,
@@ -1310,8 +1323,7 @@ def render_join(terms: list[JoinType]) -> str:
13101323
self.add_literal_import("typing", "Optional")
13111324
relationship_type = f"Optional[{relationship_type}]"
13121325
elif relationship.type == RelationshipType.MANY_TO_MANY:
1313-
self.add_literal_import("typing", "List")
1314-
relationship_type = f"List['{relationship.target.name}']"
1326+
relationship_type = f"list['{relationship.target.name}']"
13151327
else:
13161328
self.add_literal_import("typing", "Any")
13171329
relationship_type = "Any"
@@ -1409,13 +1421,6 @@ def collect_imports_for_model(self, model: Model) -> None:
14091421
if model.relationships:
14101422
self.add_literal_import("sqlmodel", "Relationship")
14111423

1412-
for relationship_attr in model.relationships:
1413-
if relationship_attr.type in (
1414-
RelationshipType.ONE_TO_MANY,
1415-
RelationshipType.MANY_TO_MANY,
1416-
):
1417-
self.add_literal_import("typing", "List")
1418-
14191424
def collect_imports_for_column(self, column: Column[Any]) -> None:
14201425
super().collect_imports_for_column(column)
14211426
try:
@@ -1487,8 +1492,7 @@ def render_relationship(self, relationship: RelationshipAttribute) -> str:
14871492
RelationshipType.ONE_TO_MANY,
14881493
RelationshipType.MANY_TO_MANY,
14891494
):
1490-
self.add_literal_import("typing", "List")
1491-
annotation = f"List[{annotation}]"
1495+
annotation = f"list[{annotation}]"
14921496
else:
14931497
self.add_literal_import("typing", "Optional")
14941498
annotation = f"Optional[{annotation}]"

β€Žtests/test_generator_dataclass.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def test_onetomany_optional(generator: CodeGenerator) -> None:
101101
validate_code(
102102
generator.generate(),
103103
"""\
104-
from typing import List, Optional
104+
from typing import Optional
105105
106106
from sqlalchemy import ForeignKey, Integer
107107
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \
@@ -116,7 +116,7 @@ class SimpleContainers(Base):
116116
117117
id: Mapped[int] = mapped_column(Integer, primary_key=True)
118118
119-
simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \
119+
simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
120120
back_populates='container')
121121
122122
@@ -152,8 +152,6 @@ def test_manytomany(generator: CodeGenerator) -> None:
152152
validate_code(
153153
generator.generate(),
154154
"""\
155-
from typing import List
156-
157155
from sqlalchemy import Column, ForeignKey, Integer, Table
158156
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \
159157
mapped_column, relationship
@@ -167,7 +165,7 @@ class SimpleContainers(Base):
167165
168166
id: Mapped[int] = mapped_column(Integer, primary_key=True)
169167
170-
item: Mapped[List['SimpleItems']] = relationship('SimpleItems', \
168+
item: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
171169
secondary='container_items', back_populates='container')
172170
173171
@@ -176,7 +174,7 @@ class SimpleItems(Base):
176174
177175
id: Mapped[int] = mapped_column(Integer, primary_key=True)
178176
179-
container: Mapped[List['SimpleContainers']] = \
177+
container: Mapped[list['SimpleContainers']] = \
180178
relationship('SimpleContainers', secondary='container_items', back_populates='item')
181179
182180
@@ -208,7 +206,7 @@ def test_named_foreign_key_constraints(generator: CodeGenerator) -> None:
208206
validate_code(
209207
generator.generate(),
210208
"""\
211-
from typing import List, Optional
209+
from typing import Optional
212210
213211
from sqlalchemy import ForeignKeyConstraint, Integer
214212
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \
@@ -223,7 +221,7 @@ class SimpleContainers(Base):
223221
224222
id: Mapped[int] = mapped_column(Integer, primary_key=True)
225223
226-
simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \
224+
simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
227225
back_populates='container')
228226
229227

0 commit comments

Comments
Β (0)