Skip to content

Commit e299f8e

Browse files
authored
feat: aerich.Command support async with syntax (#427)
* feat: `aerich.Command` support `async with` syntax * docs: update readme
1 parent db0cf65 commit e299f8e

File tree

3 files changed

+44
-23
lines changed

3 files changed

+44
-23
lines changed

README.md

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -226,14 +226,14 @@ from tortoise import Model, fields
226226

227227

228228
class Test(Model):
229-
date = fields.DateField(null=True, )
230-
datetime = fields.DatetimeField(auto_now=True, )
231-
decimal = fields.DecimalField(max_digits=10, decimal_places=2, )
232-
float = fields.FloatField(null=True, )
233-
id = fields.IntField(pk=True, )
234-
string = fields.CharField(max_length=200, null=True, )
235-
time = fields.TimeField(null=True, )
236-
tinyint = fields.BooleanField(null=True, )
229+
date = fields.DateField(null=True)
230+
datetime = fields.DatetimeField(auto_now=True)
231+
decimal = fields.DecimalField(max_digits=10, decimal_places=2)
232+
float = fields.FloatField(null=True)
233+
id = fields.IntField(primary_key=True)
234+
string = fields.CharField(max_length=200, null=True)
235+
time = fields.TimeField(null=True)
236+
tinyint = fields.BooleanField(null=True)
237237
```
238238

239239
Note that this command is limited and can't infer some fields, such as `IntEnumField`, `ForeignKeyField`, and others.
@@ -243,8 +243,8 @@ Note that this command is limited and can't infer some fields, such as `IntEnumF
243243
```python
244244
tortoise_orm = {
245245
"connections": {
246-
"default": expand_db_url(db_url, True),
247-
"second": expand_db_url(db_url_second, True),
246+
"default": "postgres://postgres_user:[email protected]:5432/db1",
247+
"second": "postgres://postgres_user:[email protected]:5432/db2",
248248
},
249249
"apps": {
250250
"models": {"models": ["tests.models", "aerich.models"], "default_connection": "default"},
@@ -253,7 +253,7 @@ tortoise_orm = {
253253
}
254254
```
255255

256-
You only need to specify `aerich.models` in one app, and must specify `--app` when running `aerich migrate` and so on.
256+
You only need to specify `aerich.models` in one app, and must specify `--app` when running `aerich migrate` and so on, e.g. `aerich --app models_second migrate`.
257257

258258
## Restore `aerich` workflow
259259

@@ -273,9 +273,9 @@ You can use `aerich` out of cli by use `Command` class.
273273
```python
274274
from aerich import Command
275275

276-
command = Command(tortoise_config=config, app='models')
277-
await command.init()
278-
await command.migrate('test')
276+
async with Command(tortoise_config=config, app='models') as command:
277+
await command.migrate('test')
278+
await command.upgrade()
279279
```
280280

281281
## Upgrade/Downgrade with `--fake` option

aerich/__init__.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22

33
import os
44
import platform
5+
from contextlib import AbstractAsyncContextManager
56
from pathlib import Path
67
from typing import TYPE_CHECKING
78

89
import tortoise
9-
from tortoise import Tortoise, generate_schema_for_client
10+
from tortoise import Tortoise, connections, generate_schema_for_client
1011
from tortoise.exceptions import OperationalError
1112
from tortoise.transactions import in_transaction
1213
from tortoise.utils import get_schema_sql
@@ -59,10 +60,9 @@ def _init_tortoise_0_24_1_patch():
5960
from tortoise.backends.base.schema_generator import BaseSchemaGenerator, cast, re
6061

6162
def _get_m2m_tables(
62-
self, model: type[Model], table_name: str, safe: bool, models_tables: list[str]
63-
) -> list[str]:
63+
self, model: type[Model], db_table: str, safe: bool, models_tables: list[str]
64+
) -> list[str]: # Copied from tortoise-orm
6465
m2m_tables_for_create = []
65-
db_table = table_name
6666
for m2m_field in model._meta.m2m_fields:
6767
field_object = cast("ManyToManyFieldInstance", model._meta.fields_map[m2m_field])
6868
if field_object._generated or field_object.through in models_tables:
@@ -88,15 +88,15 @@ def _get_m2m_tables(
8888
else:
8989
backward_fk = forward_fk = ""
9090
exists = "IF NOT EXISTS " if safe else ""
91-
table_name = field_object.through
91+
through_table_name = field_object.through
9292
backward_type = self._get_pk_field_sql_type(model._meta.pk)
9393
forward_type = self._get_pk_field_sql_type(field_object.related_model._meta.pk)
9494
comment = ""
9595
if desc := field_object.description:
96-
comment = self._table_comment_generator(table=table_name, comment=desc)
96+
comment = self._table_comment_generator(table=through_table_name, comment=desc)
9797
m2m_create_string = self.M2M_TABLE_TEMPLATE.format(
9898
exists=exists,
99-
table_name=table_name,
99+
table_name=through_table_name,
100100
backward_fk=backward_fk,
101101
forward_fk=forward_fk,
102102
backward_key=backward_key,
@@ -116,7 +116,7 @@ def _get_m2m_tables(
116116
m2m_create_string += self._post_table_hook()
117117
if field_object.create_unique_index:
118118
unique_index_create_sql = self._get_unique_index_sql(
119-
exists, table_name, [backward_key, forward_key]
119+
exists, through_table_name, [backward_key, forward_key]
120120
)
121121
if unique_index_create_sql.endswith(";"):
122122
m2m_create_string += "\n" + unique_index_create_sql
@@ -136,7 +136,7 @@ def _get_m2m_tables(
136136
_init_tortoise_0_24_1_patch()
137137

138138

139-
class Command:
139+
class Command(AbstractAsyncContextManager):
140140
def __init__(
141141
self,
142142
tortoise_config: dict,
@@ -151,6 +151,16 @@ def __init__(
151151
async def init(self) -> None:
152152
await Migrate.init(self.tortoise_config, self.app, self.location)
153153

154+
async def __aenter__(self) -> Command:
155+
await self.init()
156+
return self
157+
158+
async def close(self) -> None:
159+
await connections.close_all()
160+
161+
async def __aexit__(self, *args, **kw) -> None:
162+
await self.close()
163+
154164
async def _upgrade(self, conn, version_file, fake: bool = False) -> None:
155165
file_path = Path(Migrate.migrate_location, version_file)
156166
m = import_py_file(file_path)

tests/test_command.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from aerich import Command
2+
from conftest import tortoise_orm
3+
4+
5+
async def test_command(mocker):
6+
mocker.patch("os.listdir", return_value=[])
7+
async with Command(tortoise_orm) as command:
8+
history = await command.history()
9+
heads = await command.heads()
10+
assert history == []
11+
assert heads == []

0 commit comments

Comments
 (0)