Skip to content

Commit 1793dab

Browse files
authored
refactor: apply future type hints style (#416)
* refactor: apply future style type hints * chore: put cryptography out of dev dependencies
1 parent 6bdfdfc commit 1793dab

File tree

15 files changed

+92
-309
lines changed

15 files changed

+92
-309
lines changed

Makefile

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
checkfiles = aerich/ tests/ conftest.py
2-
black_opts = -l 100 -t py38
32
py_warn = PYTHONDEVMODE=1
43
MYSQL_HOST ?= "127.0.0.1"
54
MYSQL_PORT ?= 3306
@@ -15,12 +14,12 @@ deps:
1514
@poetry install -E asyncpg -E asyncmy -E toml
1615

1716
_style:
18-
@isort -src $(checkfiles)
19-
@black $(black_opts) $(checkfiles)
17+
@ruff check --fix $(checkfiles)
18+
@ruff format $(checkfiles)
2019
style: deps _style
2120

2221
_check:
23-
@black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
22+
@ruff format --check $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
2423
@ruff check $(checkfiles)
2524
@mypy $(checkfiles)
2625
@bandit -r aerich

aerich/__init__.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from __future__ import annotations
2+
13
import os
24
from pathlib import Path
3-
from typing import TYPE_CHECKING, List, Optional, Type
5+
from typing import TYPE_CHECKING
46

57
from tortoise import Tortoise, generate_schema_for_client
68
from tortoise.exceptions import OperationalError
@@ -21,7 +23,7 @@
2123
)
2224

2325
if TYPE_CHECKING:
24-
from aerich.inspectdb import Inspect # noqa:F401
26+
from aerich.inspectdb import Inspect
2527

2628

2729
class Command:
@@ -51,7 +53,7 @@ async def _upgrade(self, conn, version_file, fake: bool = False) -> None:
5153
content=get_models_describe(self.app),
5254
)
5355

54-
async def upgrade(self, run_in_transaction: bool = True, fake: bool = False) -> List[str]:
56+
async def upgrade(self, run_in_transaction: bool = True, fake: bool = False) -> list[str]:
5557
migrated = []
5658
for version_file in Migrate.get_all_version_files():
5759
try:
@@ -69,8 +71,8 @@ async def upgrade(self, run_in_transaction: bool = True, fake: bool = False) ->
6971
migrated.append(version_file)
7072
return migrated
7173

72-
async def downgrade(self, version: int, delete: bool, fake: bool = False) -> List[str]:
73-
ret: List[str] = []
74+
async def downgrade(self, version: int, delete: bool, fake: bool = False) -> list[str]:
75+
ret: list[str] = []
7476
if version == -1:
7577
specified_version = await Migrate.get_last_version()
7678
else:
@@ -102,23 +104,23 @@ async def downgrade(self, version: int, delete: bool, fake: bool = False) -> Lis
102104
ret.append(file)
103105
return ret
104106

105-
async def heads(self) -> List[str]:
107+
async def heads(self) -> list[str]:
106108
ret = []
107109
versions = Migrate.get_all_version_files()
108110
for version in versions:
109111
if not await Aerich.exists(version=version, app=self.app):
110112
ret.append(version)
111113
return ret
112114

113-
async def history(self) -> List[str]:
115+
async def history(self) -> list[str]:
114116
versions = Migrate.get_all_version_files()
115117
return [version for version in versions]
116118

117-
async def inspectdb(self, tables: Optional[List[str]] = None) -> str:
119+
async def inspectdb(self, tables: list[str] | None = None) -> str:
118120
connection = get_app_connection(self.tortoise_config, self.app)
119121
dialect = connection.schema_generator.DIALECT
120122
if dialect == "mysql":
121-
cls: Type["Inspect"] = InspectMySQL
123+
cls: type[Inspect] = InspectMySQL
122124
elif dialect == "postgres":
123125
cls = InspectPostgres
124126
elif dialect == "sqlite":

aerich/cli.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from __future__ import annotations
2+
13
import os
24
import sys
35
from pathlib import Path
4-
from typing import Dict, List, cast
6+
from typing import cast
57

68
import asyncclick as click
79
from asyncclick import Context, UsageError
@@ -50,7 +52,7 @@ async def cli(ctx: Context, config, app) -> None:
5052
content = config_path.read_text("utf-8")
5153
doc: dict = tomllib.loads(content)
5254
try:
53-
tool = cast(Dict[str, str], doc["tool"]["aerich"])
55+
tool = cast("dict[str, str]", doc["tool"]["aerich"])
5456
location = tool["location"]
5557
tortoise_orm = tool["tortoise_orm"]
5658
src_folder = tool.get("src_folder", CONFIG_DEFAULT_VALUES["src_folder"])
@@ -274,7 +276,7 @@ async def init_db(ctx: Context, safe: bool) -> None:
274276
required=False,
275277
)
276278
@click.pass_context
277-
async def inspectdb(ctx: Context, table: List[str]) -> None:
279+
async def inspectdb(ctx: Context, table: list[str]) -> None:
278280
command = ctx.obj["command"]
279281
ret = await command.inspectdb(table)
280282
click.secho(ret)

aerich/coder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from __future__ import annotations
2+
13
import base64
24
import json
35
import pickle # nosec: B301,B403
4-
from typing import Any, Union
6+
from typing import Any
57

68
from tortoise.indexes import Index
79

@@ -40,5 +42,5 @@ def encoder(obj: dict) -> str:
4042
return json.dumps(obj, cls=JsonEncoder)
4143

4244

43-
def decoder(obj: Union[str, bytes]) -> Any:
45+
def decoder(obj: str | bytes) -> Any:
4446
return json.loads(obj, object_hook=object_hook)

aerich/ddl/mysql/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from aerich.ddl import BaseDDL
88

99
if TYPE_CHECKING:
10-
from tortoise import Model # noqa:F401
10+
from tortoise import Model
1111

1212

1313
class MysqlDDL(BaseDDL):

aerich/ddl/sqlite/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Type
1+
from __future__ import annotations
22

33
from tortoise import Model
44
from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator
@@ -13,14 +13,14 @@ class SqliteDDL(BaseDDL):
1313
_ADD_INDEX_TEMPLATE = 'CREATE {unique}INDEX "{index_name}" ON "{table_name}" ({column_names})'
1414
_DROP_INDEX_TEMPLATE = 'DROP INDEX IF EXISTS "{index_name}"'
1515

16-
def modify_column(self, model: "Type[Model]", field_object: dict, is_pk: bool = True):
16+
def modify_column(self, model: type[Model], field_object: dict, is_pk: bool = True):
1717
raise NotSupportError("Modify column is unsupported in SQLite.")
1818

19-
def alter_column_default(self, model: "Type[Model]", field_describe: dict):
19+
def alter_column_default(self, model: type[Model], field_describe: dict):
2020
raise NotSupportError("Alter column default is unsupported in SQLite.")
2121

22-
def alter_column_null(self, model: "Type[Model]", field_describe: dict):
22+
def alter_column_null(self, model: type[Model], field_describe: dict):
2323
raise NotSupportError("Alter column null is unsupported in SQLite.")
2424

25-
def set_comment(self, model: "Type[Model]", field_describe: dict):
25+
def set_comment(self, model: type[Model], field_describe: dict):
2626
raise NotSupportError("Alter column comment is unsupported in SQLite.")

aerich/inspectdb/__init__.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import contextlib
4-
from typing import Any, Callable, Dict, Optional, TypedDict
4+
from typing import Any, Callable, Dict, TypedDict
55

66
from pydantic import BaseModel
77
from tortoise import BaseDBAsyncClient
@@ -17,6 +17,7 @@ class ColumnInfoDict(TypedDict):
1717
comment: str
1818

1919

20+
# TODO: use dict to replace typing.Dict when dropping support for Python3.8
2021
FieldMapDict = Dict[str, Callable[..., str]]
2122

2223

@@ -25,14 +26,14 @@ class Column(BaseModel):
2526
data_type: str
2627
null: bool
2728
default: Any
28-
comment: Optional[str] = None
29+
comment: str | None = None
2930
pk: bool
3031
unique: bool
3132
index: bool
32-
length: Optional[int] = None
33-
extra: Optional[str] = None
34-
decimal_places: Optional[int] = None
35-
max_digits: Optional[int] = None
33+
length: int | None = None
34+
extra: str | None = None
35+
decimal_places: int | None = None
36+
max_digits: int | None = None
3637

3738
def translate(self) -> ColumnInfoDict:
3839
comment = default = length = index = null = pk = ""

aerich/inspectdb/postgres.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
class InspectPostgres(Inspect):
12-
def __init__(self, conn: "BasePostgresClient", tables: list[str] | None = None) -> None:
12+
def __init__(self, conn: BasePostgresClient, tables: list[str] | None = None) -> None:
1313
super().__init__(conn, tables)
1414
self.schema = conn.server_settings.get("schema") or "public"
1515

aerich/migrate.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
import importlib
44
import os
5+
from collections.abc import Iterable
56
from datetime import datetime
67
from pathlib import Path
7-
from typing import Iterable, Optional, Union, cast
8+
from typing import cast
89

910
import asyncclick as click
1011
import tortoise
@@ -49,11 +50,11 @@ class Migrate:
4950

5051
ddl: BaseDDL
5152
ddl_class: type[BaseDDL]
52-
_last_version_content: Optional[dict] = None
53+
_last_version_content: dict | None = None
5354
app: str
5455
migrate_location: Path
5556
dialect: str
56-
_db_version: Optional[str] = None
57+
_db_version: str | None = None
5758

5859
@staticmethod
5960
def get_field_by_name(name: str, fields: list[dict]) -> dict:
@@ -79,7 +80,7 @@ def _get_model(cls, model: str) -> type[Model]:
7980
return Tortoise.apps[cls.app].get(model) # type: ignore
8081

8182
@classmethod
82-
async def get_last_version(cls) -> Optional[Aerich]:
83+
async def get_last_version(cls) -> Aerich | None:
8384
try:
8485
return await Aerich.filter(app=cls.app).first()
8586
except OperationalError:
@@ -113,7 +114,7 @@ async def init(cls, config: dict, app: str, location: str) -> None:
113114
await cls._get_db_version(connection)
114115

115116
@classmethod
116-
async def _get_last_version_num(cls) -> Optional[int]:
117+
async def _get_last_version_num(cls) -> int | None:
117118
last_version = await cls.get_last_version()
118119
if not last_version:
119120
return None
@@ -219,7 +220,7 @@ def _add_operator(cls, operator: str, upgrade: bool = True, fk_m2m_index: bool =
219220
cls.downgrade_operators.append(operator)
220221

221222
@classmethod
222-
def _handle_indexes(cls, model: type[Model], indexes: list[Union[tuple[str], Index]]) -> list:
223+
def _handle_indexes(cls, model: type[Model], indexes: list[tuple[str] | Index]) -> list:
223224
if tortoise.__version__ > "0.22.2":
224225
# The min version of tortoise is '0.11.0', so we can compare it by a `>`,
225226
# tortoise>0.22.2 have __eq__/__hash__ with Index class since 313ee76.
@@ -241,8 +242,8 @@ def _eq(self, other) -> bool:
241242
return indexes
242243

243244
@classmethod
244-
def _get_indexes(cls, model, model_describe: dict) -> set[Union[Index, tuple[str, ...]]]:
245-
indexes: set[Union[Index, tuple[str, ...]]] = set()
245+
def _get_indexes(cls, model, model_describe: dict) -> set[Index | tuple[str, ...]]:
246+
indexes: set[Index | tuple[str, ...]] = set()
246247
for x in cls._handle_indexes(model, model_describe.get("indexes", [])):
247248
if isinstance(x, Index):
248249
indexes.add(x)
@@ -686,7 +687,7 @@ def _resolve_fk_fields_name(cls, model: type[Model], fields_name: Iterable[str])
686687

687688
@classmethod
688689
def _drop_index(
689-
cls, model: type[Model], fields_name: Union[Iterable[str], Index], unique=False
690+
cls, model: type[Model], fields_name: Iterable[str] | Index, unique=False
690691
) -> str:
691692
if isinstance(fields_name, Index):
692693
if cls.dialect == "mysql":
@@ -707,7 +708,7 @@ def _drop_index(
707708

708709
@classmethod
709710
def _add_index(
710-
cls, model: type[Model], fields_name: Union[Iterable[str], Index], unique=False
711+
cls, model: type[Model], fields_name: Iterable[str] | Index, unique=False
711712
) -> str:
712713
if isinstance(fields_name, Index):
713714
if cls.dialect == "mysql":

aerich/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import os
55
import re
66
import sys
7+
from collections.abc import Generator
78
from pathlib import Path
89
from types import ModuleType
9-
from typing import Generator, Optional, Union
1010

1111
from asyncclick import BadOptionUsage, ClickException, Context
1212
from dictdiffer import diff
@@ -94,11 +94,11 @@ def get_models_describe(app: str) -> dict:
9494
return ret
9595

9696

97-
def is_default_function(string: str) -> Optional[re.Match]:
97+
def is_default_function(string: str) -> re.Match | None:
9898
return re.match(r"^<function.+>$", str(string or ""))
9999

100100

101-
def import_py_file(file: Union[str, Path]) -> ModuleType:
101+
def import_py_file(file: str | Path) -> ModuleType:
102102
module_name, file_ext = os.path.splitext(os.path.split(file)[-1])
103103
spec = importlib.util.spec_from_file_location(module_name, file)
104104
module = importlib.util.module_from_spec(spec) # type:ignore[arg-type]

conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from __future__ import annotations
2+
13
import asyncio
24
import os
3-
from typing import Generator
5+
from collections.abc import Generator
46

57
import pytest
68
from tortoise import Tortoise, expand_db_url

0 commit comments

Comments
 (0)