2
2
3
3
import os
4
4
import platform
5
+ from contextlib import AbstractAsyncContextManager
5
6
from pathlib import Path
6
7
from typing import TYPE_CHECKING
7
8
8
9
import tortoise
9
- from tortoise import Tortoise , generate_schema_for_client
10
+ from tortoise import Tortoise , connections , generate_schema_for_client
10
11
from tortoise .exceptions import OperationalError
11
12
from tortoise .transactions import in_transaction
12
13
from tortoise .utils import get_schema_sql
@@ -59,10 +60,9 @@ def _init_tortoise_0_24_1_patch():
59
60
from tortoise .backends .base .schema_generator import BaseSchemaGenerator , cast , re
60
61
61
62
def _get_m2m_tables (
62
- self , model : type [Model ], table_name : str , safe : bool , models_tables : list [str ]
63
- ) -> list [str ]:
63
+ self , model : type [Model ], db_table : str , safe : bool , models_tables : list [str ]
64
+ ) -> list [str ]: # Copied from tortoise-orm
64
65
m2m_tables_for_create = []
65
- db_table = table_name
66
66
for m2m_field in model ._meta .m2m_fields :
67
67
field_object = cast ("ManyToManyFieldInstance" , model ._meta .fields_map [m2m_field ])
68
68
if field_object ._generated or field_object .through in models_tables :
@@ -88,15 +88,15 @@ def _get_m2m_tables(
88
88
else :
89
89
backward_fk = forward_fk = ""
90
90
exists = "IF NOT EXISTS " if safe else ""
91
- table_name = field_object .through
91
+ through_table_name = field_object .through
92
92
backward_type = self ._get_pk_field_sql_type (model ._meta .pk )
93
93
forward_type = self ._get_pk_field_sql_type (field_object .related_model ._meta .pk )
94
94
comment = ""
95
95
if desc := field_object .description :
96
- comment = self ._table_comment_generator (table = table_name , comment = desc )
96
+ comment = self ._table_comment_generator (table = through_table_name , comment = desc )
97
97
m2m_create_string = self .M2M_TABLE_TEMPLATE .format (
98
98
exists = exists ,
99
- table_name = table_name ,
99
+ table_name = through_table_name ,
100
100
backward_fk = backward_fk ,
101
101
forward_fk = forward_fk ,
102
102
backward_key = backward_key ,
@@ -116,7 +116,7 @@ def _get_m2m_tables(
116
116
m2m_create_string += self ._post_table_hook ()
117
117
if field_object .create_unique_index :
118
118
unique_index_create_sql = self ._get_unique_index_sql (
119
- exists , table_name , [backward_key , forward_key ]
119
+ exists , through_table_name , [backward_key , forward_key ]
120
120
)
121
121
if unique_index_create_sql .endswith (";" ):
122
122
m2m_create_string += "\n " + unique_index_create_sql
@@ -136,7 +136,7 @@ def _get_m2m_tables(
136
136
_init_tortoise_0_24_1_patch ()
137
137
138
138
139
- class Command :
139
+ class Command ( AbstractAsyncContextManager ) :
140
140
def __init__ (
141
141
self ,
142
142
tortoise_config : dict ,
@@ -151,6 +151,16 @@ def __init__(
151
151
async def init (self ) -> None :
152
152
await Migrate .init (self .tortoise_config , self .app , self .location )
153
153
154
+ async def __aenter__ (self ) -> Command :
155
+ await self .init ()
156
+ return self
157
+
158
+ async def close (self ) -> None :
159
+ await connections .close_all ()
160
+
161
+ async def __aexit__ (self , * args , ** kw ) -> None :
162
+ await self .close ()
163
+
154
164
async def _upgrade (self , conn , version_file , fake : bool = False ) -> None :
155
165
file_path = Path (Migrate .migrate_location , version_file )
156
166
m = import_py_file (file_path )
0 commit comments