12
12
col ,
13
13
select ,
14
14
)
15
- from sqlmodel .ext .asyncio .session import AsyncSession
16
15
from sqlmodel .sql ._expression_select_cls import SelectBase
17
16
from web3 .types import ABI
18
17
18
+ from .database import db_session
19
19
from .utils import get_md5_abi_hash
20
20
21
21
22
22
class SqlQueryBase :
23
23
24
24
@classmethod
25
- async def get_all (cls , session : AsyncSession ):
26
- result = await session . exec (select (cls ))
27
- return result .all ()
25
+ async def get_all (cls ):
26
+ result = await db_session . execute (select (cls ))
27
+ return result .scalars (). all ()
28
28
29
- async def _save (self , session : AsyncSession ):
30
- session .add (self )
31
- await session .commit ()
29
+ async def _save (self ):
30
+ db_session .add (self )
31
+ await db_session .commit ()
32
32
return self
33
33
34
- async def update (self , session : AsyncSession ):
35
- return await self ._save (session )
34
+ async def update (self ):
35
+ return await self ._save ()
36
36
37
- async def create (self , session : AsyncSession ):
38
- return await self ._save (session )
37
+ async def create (self ):
38
+ return await self ._save ()
39
39
40
40
41
41
class TimeStampedSQLModel (SQLModel ):
@@ -69,33 +69,30 @@ class AbiSource(SqlQueryBase, SQLModel, table=True):
69
69
abis : list ["Abi" ] = Relationship (back_populates = "source" )
70
70
71
71
@classmethod
72
- async def get_or_create (
73
- cls , session : AsyncSession , name : str , url : str
74
- ) -> tuple ["AbiSource" , bool ]:
72
+ async def get_or_create (cls , name : str , url : str ) -> tuple ["AbiSource" , bool ]:
75
73
"""
76
74
Checks if an AbiSource with the given 'name' and 'url' exists.
77
75
If found, returns it with False. If not, creates and returns it with True.
78
76
79
- :param session: The database session.
80
77
:param name: The name to check or create.
81
78
:param url: The URL to check or create.
82
79
:return: A tuple containing the AbiSource object and a boolean indicating
83
80
whether it was created `True` or already exists `False`.
84
81
"""
85
82
query = select (cls ).where (cls .name == name , cls .url == url )
86
- results = await session . exec (query )
87
- if result := results .first ():
83
+ results = await db_session . execute (query )
84
+ if result := results .scalars (). first ():
88
85
return result , False
89
86
else :
90
87
new_item = cls (name = name , url = url )
91
- await new_item .create (session )
88
+ await new_item .create ()
92
89
return new_item , True
93
90
94
91
@classmethod
95
- async def get_abi_source (cls , session : AsyncSession , name : str ):
92
+ async def get_abi_source (cls , name : str ):
96
93
query = select (cls ).where (cls .name == name )
97
- results = await session . exec (query )
98
- if result := results .first ():
94
+ results = await db_session . execute (query )
95
+ if result := results .scalars (). first ():
99
96
return result
100
97
return None
101
98
@@ -113,48 +110,45 @@ class Abi(SqlQueryBase, TimeStampedSQLModel, table=True):
113
110
contracts : list ["Contract" ] = Relationship (back_populates = "abi" )
114
111
115
112
@classmethod
116
- async def get_abis_sorted_by_relevance (
117
- cls , session : AsyncSession
118
- ) -> AsyncIterator [ABI ]:
113
+ async def get_abis_sorted_by_relevance (cls ) -> AsyncIterator [ABI ]:
119
114
"""
120
115
:return: Abi JSON, with the ones with less relevance first
121
116
"""
122
- results = await session .exec (select (cls .abi_json ).order_by (col (cls .relevance )))
123
- for result in results :
117
+ results = await db_session .execute (
118
+ select (cls .abi_json ).order_by (col (cls .relevance ))
119
+ )
120
+ for result in results .scalars ().all ():
124
121
yield cast (ABI , result )
125
122
126
- async def create (self , session ):
123
+ async def create (self ):
127
124
self .abi_hash = get_md5_abi_hash (self .abi_json )
128
- return await self ._save (session )
125
+ return await self ._save ()
129
126
130
127
@classmethod
131
128
async def get_abi (
132
129
cls ,
133
- session : AsyncSession ,
134
130
abi_json : list [dict ] | dict ,
135
131
):
136
132
"""
137
133
Checks if an Abi exists based on the 'abi_json' by its calculated 'abi_hash'.
138
134
If it exists, returns the existing Abi. If not,
139
135
returns None.
140
136
141
- :param session: The database session.
142
137
:param abi_json: The ABI JSON to check.
143
138
:return: The Abi object if it exists, or None if it doesn't.
144
139
"""
145
140
abi_hash = get_md5_abi_hash (abi_json )
146
141
query = select (cls ).where (cls .abi_hash == abi_hash )
147
- result = await session . exec (query )
142
+ result = await db_session . execute (query )
148
143
149
- if existing_abi := result .first ():
144
+ if existing_abi := result .scalars (). first ():
150
145
return existing_abi
151
146
152
147
return None
153
148
154
149
@classmethod
155
150
async def get_or_create_abi (
156
151
cls ,
157
- session : AsyncSession ,
158
152
abi_json : list [dict ] | dict ,
159
153
source_id : int | None ,
160
154
relevance : int | None = 0 ,
@@ -163,18 +157,17 @@ async def get_or_create_abi(
163
157
Checks if an Abi with the given 'abi_json' exists.
164
158
If found, returns it with False. If not, creates and returns it with True.
165
159
166
- :param session: The database session.
167
160
:param abi_json: The ABI JSON to check.
168
161
:param relevance:
169
162
:param source_id:
170
163
:return: A tuple containing the Abi object and a boolean indicating
171
164
whether it was created `True` or already exists `False`.
172
165
"""
173
- if abi := await cls .get_abi (session , abi_json ):
166
+ if abi := await cls .get_abi (abi_json ):
174
167
return abi , False
175
168
else :
176
169
new_item = cls (abi_json = abi_json , relevance = relevance , source_id = source_id )
177
- await new_item .create (session )
170
+ await new_item .create ()
178
171
return new_item , True
179
172
180
173
@@ -230,47 +223,45 @@ def get_contracts_with_abi_query(
230
223
return query
231
224
232
225
@classmethod
233
- async def get_contract (cls , session : AsyncSession , address : bytes , chain_id : int ):
226
+ async def get_contract (cls , address : bytes , chain_id : int ):
234
227
query = (
235
228
select (cls ).where (cls .address == address ).where (cls .chain_id == chain_id )
236
229
)
237
- results = await session . exec (query )
238
- if result := results .first ():
230
+ results = await db_session . execute (query )
231
+ if result := results .scalars (). first ():
239
232
return result
240
233
return None
241
234
242
235
@classmethod
243
236
async def get_or_create (
244
237
cls ,
245
- session : AsyncSession ,
246
238
address : bytes ,
247
239
chain_id : int ,
248
240
** kwargs ,
249
241
) -> tuple ["Contract" , bool ]:
250
242
"""
251
243
Update or create the given params.
252
244
253
- :param session: The database session.
254
245
:param address:
255
246
:param chain_id:
256
247
:param kwargs:
257
248
:return: A tuple containing the Contract object and a boolean indicating
258
249
whether it was created `True` or already exists `False`.
259
250
"""
260
- if contract := await cls .get_contract (session , address , chain_id ):
251
+ if contract := await cls .get_contract (address , chain_id ):
261
252
return contract , False
262
253
else :
263
254
contract = cls (address = address , chain_id = chain_id )
264
255
# Add optional fields
265
256
for key , value in kwargs .items ():
266
257
setattr (contract , key , value )
267
258
268
- await contract .create (session )
259
+ await contract .create ()
269
260
return contract , True
270
261
271
262
@classmethod
272
263
async def get_abi_by_contract_address (
273
- cls , session : AsyncSession , address : bytes , chain_id : int | None
264
+ cls , address : bytes , chain_id : int | None
274
265
) -> ABI | None :
275
266
"""
276
267
:return: Json ABI given the contract `address` and `chain_id`. If `chain_id` is not given,
@@ -287,22 +278,21 @@ async def get_abi_by_contract_address(
287
278
else :
288
279
query = query .order_by (col (cls .chain_id ))
289
280
290
- results = await session . exec (query )
291
- if result := results .first ():
281
+ results = await db_session . execute (query )
282
+ if result := results .scalars (). first ():
292
283
return cast (ABI , result )
293
284
return None
294
285
295
286
@classmethod
296
287
async def get_contracts_without_abi (
297
- cls , session : AsyncSession , max_retries : int = 0
288
+ cls , max_retries : int = 0
298
289
) -> AsyncIterator [Self ]:
299
290
"""
300
291
Fetches contracts without an ABI and fewer retries than max_retries,
301
292
streaming results in batches to reduce memory usage for large datasets.
302
293
More information about streaming results can be found here:
303
294
https://docs.sqlalchemy.org/en/20/core/connections.html#streaming-with-a-dynamically-growing-buffer-using-stream-results
304
295
305
- :param session:
306
296
:param max_retries:
307
297
:return:
308
298
"""
@@ -311,19 +301,18 @@ async def get_contracts_without_abi(
311
301
.where (cls .abi_id == None ) # noqa: E711
312
302
.where (cls .fetch_retries <= max_retries )
313
303
)
314
- result = await session .stream (query )
304
+ result = await db_session .stream (query )
315
305
async for (contract ,) in result :
316
306
yield contract
317
307
318
308
@classmethod
319
- async def get_proxy_contracts (cls , session : AsyncSession ) -> AsyncIterator [Self ]:
309
+ async def get_proxy_contracts (cls ) -> AsyncIterator [Self ]:
320
310
"""
321
311
Return all the contracts with implementation address, so proxy contracts.
322
312
323
- :param session:
324
313
:return:
325
314
"""
326
315
query = select (cls ).where (cls .implementation .isnot (None )) # type: ignore
327
- result = await session .stream (query )
316
+ result = await db_session .stream (query )
328
317
async for (contract ,) in result :
329
318
yield contract
0 commit comments