Skip to content

Commit f9da5b3

Browse files
authored
Refactor database session (#79)
1 parent be533a6 commit f9da5b3

21 files changed

+447
-462
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,13 @@ To open an interactive Python shell within a Docker container and query the data
5353
```
5454
Example usage:
5555
```
56-
In [11]: contracts = await Contract.get_all(session)
56+
In [11]: contracts = await Contract.get_all()
5757
5858
In [12]: contracts[0].address
5959
Out[12]: b'J\xdb\xaa\xc7\xbc#\x9e%\x19\xcb\xfd#\x97\xe0\xf7Z\x1d\xe3U\xc8'
6060
6161
```
62+
Call `await restore_session()` to reopen a new session.
6263

6364
## Contributors
6465
[See contributors](https://github.com/safe-global/safe-decoder-service/graphs/contributors)

app/datasources/db/database.py

Lines changed: 59 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,29 @@
11
import logging
2-
from collections.abc import AsyncGenerator
2+
import uuid
3+
from contextlib import contextmanager
4+
from contextvars import ContextVar
35
from functools import cache, wraps
6+
from typing import Generator
47

5-
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
8+
from sqlalchemy.ext.asyncio import (
9+
AsyncEngine,
10+
async_scoped_session,
11+
async_sessionmaker,
12+
create_async_engine,
13+
)
614
from sqlalchemy.pool import AsyncAdaptedQueuePool, NullPool
7-
from sqlmodel.ext.asyncio.session import AsyncSession
815

916
from ...config import settings
1017

18+
logger = logging.getLogger(__name__)
19+
1120
pool_classes = {
1221
NullPool.__name__: NullPool,
1322
AsyncAdaptedQueuePool.__name__: AsyncAdaptedQueuePool,
1423
}
1524

25+
_db_session_context: ContextVar[str] = ContextVar("db_session_context")
26+
1627

1728
@cache
1829
def get_engine() -> AsyncEngine:
@@ -35,28 +46,59 @@ def get_engine() -> AsyncEngine:
3546
)
3647

3748

38-
async def get_database_session() -> AsyncGenerator:
39-
async with AsyncSession(get_engine(), expire_on_commit=False) as session:
40-
yield session
49+
@contextmanager
50+
def set_database_session_context(
51+
session_id: str | None = None,
52+
) -> Generator[None, None, None]:
53+
"""
54+
Set session ContextVar, at the end it will be removed.
55+
This context is designed to be used with `async_scoped_session` to define a context scope.
56+
57+
:param session_id:
58+
:return:
59+
"""
60+
_session_id: str = session_id or str(uuid.uuid4())
61+
logger.debug(f"Storing db_session context: {_session_id}")
62+
token = _db_session_context.set(_session_id)
63+
try:
64+
yield
65+
finally:
66+
logger.debug(f"Removing db_session context: {_session_id}")
67+
_db_session_context.reset(token)
4168

4269

43-
def database_session(func):
70+
def _get_database_session_context() -> str:
4471
"""
45-
Decorator that creates a new database session for the given function
72+
Get the database session id from the ContextVar.
73+
Used as a scope function on `async_scoped_session`.
4674
47-
:param func:
48-
:return:
75+
:return: session_id for the current context
76+
"""
77+
return _db_session_context.get()
78+
79+
80+
def db_session_context(func):
81+
"""
82+
Wrap the decorated function in the `set_database_session_context` context.
83+
Decorated function will share the same database session.
84+
Remove the session at the end of the context.
4985
"""
5086

5187
@wraps(func)
5288
async def wrapper(*args, **kwargs):
53-
async with AsyncSession(get_engine(), expire_on_commit=False) as session:
89+
with set_database_session_context():
5490
try:
55-
return await func(*args, **kwargs, session=session)
56-
except Exception as e:
57-
# Rollback errors
58-
await session.rollback()
59-
logging.error(f"Error occurred: {e}")
60-
raise
91+
return await func(*args, **kwargs)
92+
finally:
93+
logger.debug(
94+
f"Removing session context: {_get_database_session_context()}"
95+
)
96+
await db_session.remove()
6197

6298
return wrapper
99+
100+
101+
async_session_factory = async_sessionmaker(get_engine(), expire_on_commit=False)
102+
db_session = async_scoped_session(
103+
session_factory=async_session_factory, scopefunc=_get_database_session_context
104+
)

app/datasources/db/models.py

Lines changed: 41 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,30 @@
1212
col,
1313
select,
1414
)
15-
from sqlmodel.ext.asyncio.session import AsyncSession
1615
from sqlmodel.sql._expression_select_cls import SelectBase
1716
from web3.types import ABI
1817

18+
from .database import db_session
1919
from .utils import get_md5_abi_hash
2020

2121

2222
class SqlQueryBase:
2323

2424
@classmethod
25-
async def get_all(cls, session: AsyncSession):
26-
result = await session.exec(select(cls))
27-
return result.all()
25+
async def get_all(cls):
26+
result = await db_session.execute(select(cls))
27+
return result.scalars().all()
2828

29-
async def _save(self, session: AsyncSession):
30-
session.add(self)
31-
await session.commit()
29+
async def _save(self):
30+
db_session.add(self)
31+
await db_session.commit()
3232
return self
3333

34-
async def update(self, session: AsyncSession):
35-
return await self._save(session)
34+
async def update(self):
35+
return await self._save()
3636

37-
async def create(self, session: AsyncSession):
38-
return await self._save(session)
37+
async def create(self):
38+
return await self._save()
3939

4040

4141
class TimeStampedSQLModel(SQLModel):
@@ -69,33 +69,30 @@ class AbiSource(SqlQueryBase, SQLModel, table=True):
6969
abis: list["Abi"] = Relationship(back_populates="source")
7070

7171
@classmethod
72-
async def get_or_create(
73-
cls, session: AsyncSession, name: str, url: str
74-
) -> tuple["AbiSource", bool]:
72+
async def get_or_create(cls, name: str, url: str) -> tuple["AbiSource", bool]:
7573
"""
7674
Checks if an AbiSource with the given 'name' and 'url' exists.
7775
If found, returns it with False. If not, creates and returns it with True.
7876
79-
:param session: The database session.
8077
:param name: The name to check or create.
8178
:param url: The URL to check or create.
8279
:return: A tuple containing the AbiSource object and a boolean indicating
8380
whether it was created `True` or already exists `False`.
8481
"""
8582
query = select(cls).where(cls.name == name, cls.url == url)
86-
results = await session.exec(query)
87-
if result := results.first():
83+
results = await db_session.execute(query)
84+
if result := results.scalars().first():
8885
return result, False
8986
else:
9087
new_item = cls(name=name, url=url)
91-
await new_item.create(session)
88+
await new_item.create()
9289
return new_item, True
9390

9491
@classmethod
95-
async def get_abi_source(cls, session: AsyncSession, name: str):
92+
async def get_abi_source(cls, name: str):
9693
query = select(cls).where(cls.name == name)
97-
results = await session.exec(query)
98-
if result := results.first():
94+
results = await db_session.execute(query)
95+
if result := results.scalars().first():
9996
return result
10097
return None
10198

@@ -113,48 +110,45 @@ class Abi(SqlQueryBase, TimeStampedSQLModel, table=True):
113110
contracts: list["Contract"] = Relationship(back_populates="abi")
114111

115112
@classmethod
116-
async def get_abis_sorted_by_relevance(
117-
cls, session: AsyncSession
118-
) -> AsyncIterator[ABI]:
113+
async def get_abis_sorted_by_relevance(cls) -> AsyncIterator[ABI]:
119114
"""
120115
:return: Abi JSON, with the ones with less relevance first
121116
"""
122-
results = await session.exec(select(cls.abi_json).order_by(col(cls.relevance)))
123-
for result in results:
117+
results = await db_session.execute(
118+
select(cls.abi_json).order_by(col(cls.relevance))
119+
)
120+
for result in results.scalars().all():
124121
yield cast(ABI, result)
125122

126-
async def create(self, session):
123+
async def create(self):
127124
self.abi_hash = get_md5_abi_hash(self.abi_json)
128-
return await self._save(session)
125+
return await self._save()
129126

130127
@classmethod
131128
async def get_abi(
132129
cls,
133-
session: AsyncSession,
134130
abi_json: list[dict] | dict,
135131
):
136132
"""
137133
Checks if an Abi exists based on the 'abi_json' by its calculated 'abi_hash'.
138134
If it exists, returns the existing Abi. If not,
139135
returns None.
140136
141-
:param session: The database session.
142137
:param abi_json: The ABI JSON to check.
143138
:return: The Abi object if it exists, or None if it doesn't.
144139
"""
145140
abi_hash = get_md5_abi_hash(abi_json)
146141
query = select(cls).where(cls.abi_hash == abi_hash)
147-
result = await session.exec(query)
142+
result = await db_session.execute(query)
148143

149-
if existing_abi := result.first():
144+
if existing_abi := result.scalars().first():
150145
return existing_abi
151146

152147
return None
153148

154149
@classmethod
155150
async def get_or_create_abi(
156151
cls,
157-
session: AsyncSession,
158152
abi_json: list[dict] | dict,
159153
source_id: int | None,
160154
relevance: int | None = 0,
@@ -163,18 +157,17 @@ async def get_or_create_abi(
163157
Checks if an Abi with the given 'abi_json' exists.
164158
If found, returns it with False. If not, creates and returns it with True.
165159
166-
:param session: The database session.
167160
:param abi_json: The ABI JSON to check.
168161
:param relevance:
169162
:param source_id:
170163
:return: A tuple containing the Abi object and a boolean indicating
171164
whether it was created `True` or already exists `False`.
172165
"""
173-
if abi := await cls.get_abi(session, abi_json):
166+
if abi := await cls.get_abi(abi_json):
174167
return abi, False
175168
else:
176169
new_item = cls(abi_json=abi_json, relevance=relevance, source_id=source_id)
177-
await new_item.create(session)
170+
await new_item.create()
178171
return new_item, True
179172

180173

@@ -230,47 +223,45 @@ def get_contracts_with_abi_query(
230223
return query
231224

232225
@classmethod
233-
async def get_contract(cls, session: AsyncSession, address: bytes, chain_id: int):
226+
async def get_contract(cls, address: bytes, chain_id: int):
234227
query = (
235228
select(cls).where(cls.address == address).where(cls.chain_id == chain_id)
236229
)
237-
results = await session.exec(query)
238-
if result := results.first():
230+
results = await db_session.execute(query)
231+
if result := results.scalars().first():
239232
return result
240233
return None
241234

242235
@classmethod
243236
async def get_or_create(
244237
cls,
245-
session: AsyncSession,
246238
address: bytes,
247239
chain_id: int,
248240
**kwargs,
249241
) -> tuple["Contract", bool]:
250242
"""
251243
Update or create the given params.
252244
253-
:param session: The database session.
254245
:param address:
255246
:param chain_id:
256247
:param kwargs:
257248
:return: A tuple containing the Contract object and a boolean indicating
258249
whether it was created `True` or already exists `False`.
259250
"""
260-
if contract := await cls.get_contract(session, address, chain_id):
251+
if contract := await cls.get_contract(address, chain_id):
261252
return contract, False
262253
else:
263254
contract = cls(address=address, chain_id=chain_id)
264255
# Add optional fields
265256
for key, value in kwargs.items():
266257
setattr(contract, key, value)
267258

268-
await contract.create(session)
259+
await contract.create()
269260
return contract, True
270261

271262
@classmethod
272263
async def get_abi_by_contract_address(
273-
cls, session: AsyncSession, address: bytes, chain_id: int | None
264+
cls, address: bytes, chain_id: int | None
274265
) -> ABI | None:
275266
"""
276267
:return: Json ABI given the contract `address` and `chain_id`. If `chain_id` is not given,
@@ -287,22 +278,21 @@ async def get_abi_by_contract_address(
287278
else:
288279
query = query.order_by(col(cls.chain_id))
289280

290-
results = await session.exec(query)
291-
if result := results.first():
281+
results = await db_session.execute(query)
282+
if result := results.scalars().first():
292283
return cast(ABI, result)
293284
return None
294285

295286
@classmethod
296287
async def get_contracts_without_abi(
297-
cls, session: AsyncSession, max_retries: int = 0
288+
cls, max_retries: int = 0
298289
) -> AsyncIterator[Self]:
299290
"""
300291
Fetches contracts without an ABI and fewer retries than max_retries,
301292
streaming results in batches to reduce memory usage for large datasets.
302293
More information about streaming results can be found here:
303294
https://docs.sqlalchemy.org/en/20/core/connections.html#streaming-with-a-dynamically-growing-buffer-using-stream-results
304295
305-
:param session:
306296
:param max_retries:
307297
:return:
308298
"""
@@ -311,19 +301,18 @@ async def get_contracts_without_abi(
311301
.where(cls.abi_id == None) # noqa: E711
312302
.where(cls.fetch_retries <= max_retries)
313303
)
314-
result = await session.stream(query)
304+
result = await db_session.stream(query)
315305
async for (contract,) in result:
316306
yield contract
317307

318308
@classmethod
319-
async def get_proxy_contracts(cls, session: AsyncSession) -> AsyncIterator[Self]:
309+
async def get_proxy_contracts(cls) -> AsyncIterator[Self]:
320310
"""
321311
Return all the contracts with implementation address, so proxy contracts.
322312
323-
:param session:
324313
:return:
325314
"""
326315
query = select(cls).where(cls.implementation.isnot(None)) # type: ignore
327-
result = await session.stream(query)
316+
result = await db_session.stream(query)
328317
async for (contract,) in result:
329318
yield contract

0 commit comments

Comments
 (0)